From ba7534f274c4231d5e79687845ae97887ba81179 Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Mon, 15 Sep 2025 16:12:28 +0530 Subject: [PATCH 01/11] fix: jwt validation blocks on multiple requests --- router-tests/authentication_test.go | 45 +++++++++++++++++-- router-tests/header_set_test.go | 2 +- router-tests/jwks/jwks.go | 12 ++++- router-tests/testenv/testenv.go | 21 +++++++++ .../pkg/authentication/jwks_token_decoder.go | 7 +-- 5 files changed, 75 insertions(+), 12 deletions(-) diff --git a/router-tests/authentication_test.go b/router-tests/authentication_test.go index 38d2715a56..561e36a8a9 100644 --- a/router-tests/authentication_test.go +++ b/router-tests/authentication_test.go @@ -828,6 +828,43 @@ func TestHttpJwksAuthorization(t *testing.T) { }) }) + t.Run("authentication should not block with an invalid token on multiple calls", func(t *testing.T) { + t.Parallel() + + authenticators, authServer := ConfigureAuth(t) + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // This token has a static keyid + token, err := authServer.TokenForKID("wg_static_kid", nil, true) + maxDuration := 5 * time.Second + + xEnv.WaitForTest(maxDuration, func() { + require.NoError(t, err) + for range 5 { + func() { + // Operations with an invalid token should fail + header := http.Header{ + "Authorization": []string{"Bearer " + token}, + } + res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { + _ = res.Body.Close() + }() + require.Equal(t, http.StatusUnauthorized, res.StatusCode) + require.Equal(t, "", res.Header.Get(xAuthenticatedByHeader)) + data, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.JSONEq(t, unauthorizedExpectedData, string(data)) + }() + } + }) + }) + }) + } func TestNonHttpAuthorization(t *testing.T) { @@ -1182,7 +1219,7 @@ func TestAlgorithmMismatch(t *testing.T) { authenticators := []authentication.Authenticator{authenticator} - token, err := authServer.TokenForKID(crypto.KID(), nil) + token, err := authServer.TokenForKID(crypto.KID(), nil, false) require.NoError(t, err) return token, authenticators @@ -1307,7 +1344,7 @@ func TestOidcDiscovery(t *testing.T) { tokens := make(map[string]string) for _, c := range crypto { - token, err := authServer.TokenForKID(c.KID(), nil) + token, err := authServer.TokenForKID(c.KID(), nil, false) require.NoError(t, err) tokens[c.KID()] = token @@ -1421,7 +1458,7 @@ func TestMultipleKeys(t *testing.T) { tokens := make(map[string]string) for _, c := range crypto { - token, err := authServer.TokenForKID(c.KID(), nil) + token, err := authServer.TokenForKID(c.KID(), nil, false) require.NoError(t, err) tokens[c.KID()] = token @@ -1604,7 +1641,7 @@ func TestSupportedAlgorithms(t *testing.T) { authenticators := []authentication.Authenticator{authenticator} - token, err := authServer.TokenForKID(crypto.KID(), nil) + token, err := authServer.TokenForKID(crypto.KID(), nil, false) require.NoError(t, err) return token, authenticators diff --git a/router-tests/header_set_test.go b/router-tests/header_set_test.go index 763663e2b5..a8f15d69af 100644 --- a/router-tests/header_set_test.go +++ b/router-tests/header_set_test.go @@ -275,7 +275,7 @@ func TestHeaderSetWithExpression(t *testing.T) { authenticator, err := authentication.NewHttpHeaderAuthenticator(authOptions) require.NoError(t, err) - token, err := authServer.TokenForKID(rsa1.KID(), map[string]any{"user_id": "TestId"}) + token, err := authServer.TokenForKID(rsa1.KID(), map[string]any{"user_id": "TestId"}, false) require.NoError(t, err) testenv.Run(t, &testenv.Config{ diff --git a/router-tests/jwks/jwks.go b/router-tests/jwks/jwks.go index 080879aa0c..78a414ab9b 100644 --- a/router-tests/jwks/jwks.go +++ b/router-tests/jwks/jwks.go @@ -50,11 +50,19 @@ func (s *Server) Token(claims map[string]any) (string, error) { return "", jwt.ErrInvalidKey } -func (s *Server) TokenForKID(kid string, claims map[string]any) (string, error) { +func (s *Server) TokenForKID(kid string, claims map[string]any, useInvalidKid bool) (string, error) { provider, ok := s.providers[kid] - if !ok { + if !useInvalidKid && !ok { return "", jwt.ErrInvalidKey + } else if useInvalidKid { + // If we don't care about the kid we don't care about the provider + // we just get the first random provider provided + for _, pr := range s.providers { + provider = pr + break + } } + token := jwt.NewWithClaims(provider.SigningMethod(), jwt.MapClaims(claims)) token.Header[jwkset.HeaderKID] = kid return token.SignedString(provider.PrivateKey()) diff --git a/router-tests/testenv/testenv.go b/router-tests/testenv/testenv.go index bc6cf1df84..30cc7929bf 100644 --- a/router-tests/testenv/testenv.go +++ b/router-tests/testenv/testenv.go @@ -2532,6 +2532,27 @@ func (e *Environment) WaitForConnectionCount(desiredCount uint64, timeout time.D } } +func (e *Environment) WaitForTest(timeout time.Duration, testFunction func()) { + e.t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + done := make(chan struct{}) + go func() { + defer close(done) + testFunction() + }() + + select { + case <-done: + return + case <-ctx.Done(): + e.t.Fatalf("test timed out, want %d", timeout) + return + } +} + type EngineStatisticAssertion struct { Subscriptions int64 Connections int64 diff --git a/router/pkg/authentication/jwks_token_decoder.go b/router/pkg/authentication/jwks_token_decoder.go index 53b252632c..cae860409f 100644 --- a/router/pkg/authentication/jwks_token_decoder.go +++ b/router/pkg/authentication/jwks_token_decoder.go @@ -10,10 +10,8 @@ import ( "github.com/MicahParks/jwkset" "github.com/MicahParks/keyfunc/v3" "github.com/golang-jwt/jwt/v5" - "go.uber.org/zap" - "golang.org/x/time/rate" - "github.com/wundergraph/cosmo/router/internal/httpclient" + "go.uber.org/zap" ) type TokenDecoder interface { @@ -95,8 +93,7 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS HTTPURLs: map[string]jwkset.Storage{ c.URL: store, }, - PrioritizeHTTP: true, - RefreshUnknownKID: rate.NewLimiter(rate.Every(5*time.Minute), 1), + PrioritizeHTTP: true, } jwks, err := createKeyFunc(ctx, jwksetHTTPClientOptions) From 79adec8d1b51bd4ab61d4a971cb01a29c2a44510 Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Wed, 17 Sep 2025 17:55:10 +0530 Subject: [PATCH 02/11] fix: review comments --- router-tests/authentication_test.go | 11 ++++++++--- router-tests/jwks/jwks.go | 6 +++--- router/go.mod | 2 +- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/router-tests/authentication_test.go b/router-tests/authentication_test.go index 561e36a8a9..0e6f006285 100644 --- a/router-tests/authentication_test.go +++ b/router-tests/authentication_test.go @@ -839,10 +839,13 @@ func TestHttpJwksAuthorization(t *testing.T) { }, func(t *testing.T, xEnv *testenv.Environment) { // This token has a static keyid token, err := authServer.TokenForKID("wg_static_kid", nil, true) + require.NoError(t, err) + maxDuration := 5 * time.Second - xEnv.WaitForTest(maxDuration, func() { - require.NoError(t, err) + doneCh := make(chan struct{}) + go func() { + defer close(doneCh) for range 5 { func() { // Operations with an invalid token should fail @@ -861,7 +864,9 @@ func TestHttpJwksAuthorization(t *testing.T) { require.JSONEq(t, unauthorizedExpectedData, string(data)) }() } - }) + }() + + testenv.AwaitChannelWithT(t, maxDuration, doneCh, func(t *testing.T, _ struct{}) {}, "test timed out") }) }) diff --git a/router-tests/jwks/jwks.go b/router-tests/jwks/jwks.go index 78a414ab9b..07e0ada272 100644 --- a/router-tests/jwks/jwks.go +++ b/router-tests/jwks/jwks.go @@ -50,11 +50,11 @@ func (s *Server) Token(claims map[string]any) (string, error) { return "", jwt.ErrInvalidKey } -func (s *Server) TokenForKID(kid string, claims map[string]any, useInvalidKid bool) (string, error) { +func (s *Server) TokenForKID(kid string, claims map[string]any, useInvalidKID bool) (string, error) { provider, ok := s.providers[kid] - if !useInvalidKid && !ok { + if !useInvalidKID && !ok { return "", jwt.ErrInvalidKey - } else if useInvalidKid { + } else if useInvalidKID { // If we don't care about the kid we don't care about the provider // we just get the first random provider provided for _, pr := range s.providers { diff --git a/router/go.mod b/router/go.mod index ae51779a4f..dc8761f12e 100644 --- a/router/go.mod +++ b/router/go.mod @@ -83,7 +83,6 @@ require ( go.uber.org/ratelimit v0.3.1 golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 golang.org/x/text v0.23.0 - golang.org/x/time v0.9.0 ) require ( @@ -166,6 +165,7 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/crypto v0.36.0 // indirect golang.org/x/net v0.38.0 // indirect + golang.org/x/time v0.9.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f // indirect gopkg.in/cenkalti/backoff.v1 v1.1.0 // indirect From ca0aaaca2130254aeda27a1aafab0d587c26ada6 Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Wed, 17 Sep 2025 17:59:58 +0530 Subject: [PATCH 03/11] fix: cleanup --- router-tests/testenv/testenv.go | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/router-tests/testenv/testenv.go b/router-tests/testenv/testenv.go index 30cc7929bf..bc6cf1df84 100644 --- a/router-tests/testenv/testenv.go +++ b/router-tests/testenv/testenv.go @@ -2532,27 +2532,6 @@ func (e *Environment) WaitForConnectionCount(desiredCount uint64, timeout time.D } } -func (e *Environment) WaitForTest(timeout time.Duration, testFunction func()) { - e.t.Helper() - - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - - done := make(chan struct{}) - go func() { - defer close(done) - testFunction() - }() - - select { - case <-done: - return - case <-ctx.Done(): - e.t.Fatalf("test timed out, want %d", timeout) - return - } -} - type EngineStatisticAssertion struct { Subscriptions int64 Connections int64 From e0f1e53e2e20843d4cf505d4b947fdbaf4f91957 Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Wed, 17 Sep 2025 18:45:01 +0530 Subject: [PATCH 04/11] fix: require equals --- router-tests/authentication_test.go | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/router-tests/authentication_test.go b/router-tests/authentication_test.go index 0e6f006285..06b2beb49b 100644 --- a/router-tests/authentication_test.go +++ b/router-tests/authentication_test.go @@ -843,9 +843,7 @@ func TestHttpJwksAuthorization(t *testing.T) { maxDuration := 5 * time.Second - doneCh := make(chan struct{}) - go func() { - defer close(doneCh) + require.Eventually(t, func() bool { for range 5 { func() { // Operations with an invalid token should fail @@ -864,9 +862,8 @@ func TestHttpJwksAuthorization(t *testing.T) { require.JSONEq(t, unauthorizedExpectedData, string(data)) }() } - }() - - testenv.AwaitChannelWithT(t, maxDuration, doneCh, func(t *testing.T, _ struct{}) {}, "test timed out") + return true + }, maxDuration, 10*time.Millisecond) }) }) From 3166aa801ca5bd28051e695d207023102418cc6e Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Wed, 17 Sep 2025 18:47:32 +0530 Subject: [PATCH 05/11] Revert "fix: require equals" This reverts commit e0f1e53e2e20843d4cf505d4b947fdbaf4f91957. --- router-tests/authentication_test.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/router-tests/authentication_test.go b/router-tests/authentication_test.go index 06b2beb49b..0e6f006285 100644 --- a/router-tests/authentication_test.go +++ b/router-tests/authentication_test.go @@ -843,7 +843,9 @@ func TestHttpJwksAuthorization(t *testing.T) { maxDuration := 5 * time.Second - require.Eventually(t, func() bool { + doneCh := make(chan struct{}) + go func() { + defer close(doneCh) for range 5 { func() { // Operations with an invalid token should fail @@ -862,8 +864,9 @@ func TestHttpJwksAuthorization(t *testing.T) { require.JSONEq(t, unauthorizedExpectedData, string(data)) }() } - return true - }, maxDuration, 10*time.Millisecond) + }() + + testenv.AwaitChannelWithT(t, maxDuration, doneCh, func(t *testing.T, _ struct{}) {}, "test timed out") }) }) From 3e9ff2aaf4a8a82a5a8b433d8ee1c3b61907f221 Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Wed, 17 Sep 2025 19:01:30 +0530 Subject: [PATCH 06/11] fix: tests --- router-tests/authentication_test.go | 8 ++------ router-tests/testenv/utils.go | 12 ++++++++++++ 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/router-tests/authentication_test.go b/router-tests/authentication_test.go index 0e6f006285..9a41fda9b0 100644 --- a/router-tests/authentication_test.go +++ b/router-tests/authentication_test.go @@ -843,9 +843,7 @@ func TestHttpJwksAuthorization(t *testing.T) { maxDuration := 5 * time.Second - doneCh := make(chan struct{}) - go func() { - defer close(doneCh) + testenv.AwaitFunc(t, maxDuration, func() { for range 5 { func() { // Operations with an invalid token should fail @@ -864,9 +862,7 @@ func TestHttpJwksAuthorization(t *testing.T) { require.JSONEq(t, unauthorizedExpectedData, string(data)) }() } - }() - - testenv.AwaitChannelWithT(t, maxDuration, doneCh, func(t *testing.T, _ struct{}) {}, "test timed out") + }) }) }) diff --git a/router-tests/testenv/utils.go b/router-tests/testenv/utils.go index bd4f1842db..19d13ba27c 100644 --- a/router-tests/testenv/utils.go +++ b/router-tests/testenv/utils.go @@ -28,3 +28,15 @@ func AwaitChannelWithCloseWithT[A any](t *testing.T, timeout time.Duration, ch < require.Fail(t, "unable to receive message before timeout", msgAndArgs...) } } + +func AwaitFunc(t *testing.T, timeout time.Duration, testFunction func()) { + t.Helper() + + doneCh := make(chan struct{}) + go func() { + defer close(doneCh) + testFunction() + }() + + AwaitChannelWithT(t, timeout, doneCh, func(t *testing.T, _ struct{}) {}, "the test function timed out") +} From 6978e0d52bdff4335ca4709204249a49d06405bf Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Fri, 19 Sep 2025 00:46:31 +0530 Subject: [PATCH 07/11] fix: tests --- router-tests/jwks/jwks.go | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/router-tests/jwks/jwks.go b/router-tests/jwks/jwks.go index 07e0ada272..6b77a76812 100644 --- a/router-tests/jwks/jwks.go +++ b/router-tests/jwks/jwks.go @@ -52,15 +52,14 @@ func (s *Server) Token(claims map[string]any) (string, error) { func (s *Server) TokenForKID(kid string, claims map[string]any, useInvalidKID bool) (string, error) { provider, ok := s.providers[kid] - if !useInvalidKID && !ok { - return "", jwt.ErrInvalidKey - } else if useInvalidKID { - // If we don't care about the kid we don't care about the provider - // we just get the first random provider provided + if useInvalidKID { + // If we don't care about the kid, use any available provider for _, pr := range s.providers { provider = pr break } + } else if !ok { + return "", jwt.ErrInvalidKey } token := jwt.NewWithClaims(provider.SigningMethod(), jwt.MapClaims(claims)) From a04b95323b2371fba41a8eddcdf5ed949b60ca59 Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Mon, 22 Sep 2025 01:48:40 +0530 Subject: [PATCH 08/11] fix: make rate limit values configurable --- router-tests/authentication_test.go | 183 ++++++++++++++---- router/core/supervisor_instance.go | 5 + .../pkg/authentication/jwks_token_decoder.go | 19 +- router/pkg/config/config.go | 13 +- router/pkg/config/config.schema.json | 26 +++ router/pkg/config/testdata/config_full.json | 15 ++ 6 files changed, 219 insertions(+), 42 deletions(-) diff --git a/router-tests/authentication_test.go b/router-tests/authentication_test.go index 9a41fda9b0..ab8650c381 100644 --- a/router-tests/authentication_test.go +++ b/router-tests/authentication_test.go @@ -56,6 +56,151 @@ func TestAuthentication(t *testing.T) { }) }) + t.Run("unknown kid refresh blocks when burst exceeded", func(t *testing.T) { + t.Parallel() + + authServer, err := jwks.NewServer(t) + require.NoError(t, err) + t.Cleanup(authServer.Close) + + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + URL: authServer.JWKSURL(), + RefreshInterval: 10 * time.Second, + RefreshUnknownKID: authentication.RefreshUnknownKIDConfig{ + Enabled: true, + Interval: 1 * time.Second, + Burst: 1, + }, + }, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + token, err := authServer.TokenForKID("unknown_kid", nil, true) + require.NoError(t, err) + + header := http.Header{"Authorization": []string{"Bearer " + token}} + + res1, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res1.Body.Close() }() + require.Equal(t, http.StatusUnauthorized, res1.StatusCode) + _, err = io.ReadAll(res1.Body) + require.NoError(t, err) + + start := time.Now() + res2, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res2.Body.Close() }() + elapsed := time.Since(start) + + require.True(t, elapsed >= 600*time.Millisecond) + require.Equal(t, http.StatusUnauthorized, res2.StatusCode) + data, err := io.ReadAll(res2.Body) + require.NoError(t, err) + require.JSONEq(t, unauthorizedExpectedData, string(data)) + }) + }) + + t.Run("unknown kid refresh does not block when burst not exceeded", func(t *testing.T) { + t.Parallel() + + authServer, err := jwks.NewServer(t) + require.NoError(t, err) + t.Cleanup(authServer.Close) + + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + URL: authServer.JWKSURL(), + RefreshInterval: 10 * time.Second, + RefreshUnknownKID: authentication.RefreshUnknownKIDConfig{ + Enabled: true, + Interval: 1 * time.Second, + Burst: 1, + }, + }, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + token, err := authServer.TokenForKID("unknown_kid", nil, true) + require.NoError(t, err) + header := http.Header{"Authorization": []string{"Bearer " + token}} + + res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res.Body.Close() }() + require.Equal(t, http.StatusUnauthorized, res.StatusCode) + _, err = io.ReadAll(res.Body) + require.NoError(t, err) + + // Wait for interval so next refresh is within burst budget + time.Sleep(1200 * time.Millisecond) + + start := time.Now() + res2, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res2.Body.Close() }() + elapsed := time.Since(start) + require.True(t, elapsed < 100*time.Millisecond) + require.Equal(t, http.StatusUnauthorized, res2.StatusCode) + data, err := io.ReadAll(res2.Body) + require.NoError(t, err) + require.JSONEq(t, unauthorizedExpectedData, string(data)) + }) + }) + + t.Run("authentication should not block with unknown kid when refresh is disabled", func(t *testing.T) { + t.Parallel() + + authServer, err := jwks.NewServer(t) + require.NoError(t, err) + t.Cleanup(authServer.Close) + + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + URL: authServer.JWKSURL(), + RefreshInterval: 100 * time.Millisecond, + }, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Create a token signed with a valid key but with an unknown kid header + token, err := authServer.TokenForKID("unknown_kid", nil, true) + require.NoError(t, err) + + maxDuration := 4 * time.Second + testenv.AwaitFunc(t, maxDuration, func() { + for range 5 { + func() { + header := http.Header{ + "Authorization": []string{"Bearer " + token}, + } + res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res.Body.Close() }() + require.Equal(t, http.StatusUnauthorized, res.StatusCode) + require.Equal(t, "", res.Header.Get(xAuthenticatedByHeader)) + data, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.JSONEq(t, unauthorizedExpectedData, string(data)) + }() + } + }) + }) + }) + t.Run("invalid token", func(t *testing.T) { t.Parallel() @@ -828,44 +973,6 @@ func TestHttpJwksAuthorization(t *testing.T) { }) }) - t.Run("authentication should not block with an invalid token on multiple calls", func(t *testing.T) { - t.Parallel() - - authenticators, authServer := ConfigureAuth(t) - testenv.Run(t, &testenv.Config{ - RouterOptions: []core.Option{ - core.WithAccessController(core.NewAccessController(authenticators, true)), - }, - }, func(t *testing.T, xEnv *testenv.Environment) { - // This token has a static keyid - token, err := authServer.TokenForKID("wg_static_kid", nil, true) - require.NoError(t, err) - - maxDuration := 5 * time.Second - - testenv.AwaitFunc(t, maxDuration, func() { - for range 5 { - func() { - // Operations with an invalid token should fail - header := http.Header{ - "Authorization": []string{"Bearer " + token}, - } - res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) - require.NoError(t, err) - defer func() { - _ = res.Body.Close() - }() - require.Equal(t, http.StatusUnauthorized, res.StatusCode) - require.Equal(t, "", res.Header.Get(xAuthenticatedByHeader)) - data, err := io.ReadAll(res.Body) - require.NoError(t, err) - require.JSONEq(t, unauthorizedExpectedData, string(data)) - }() - } - }) - }) - }) - } func TestNonHttpAuthorization(t *testing.T) { diff --git a/router/core/supervisor_instance.go b/router/core/supervisor_instance.go index 09605f6d9e..b2ca6050c1 100644 --- a/router/core/supervisor_instance.go +++ b/router/core/supervisor_instance.go @@ -265,6 +265,11 @@ func setupAuthenticators(ctx context.Context, logger *zap.Logger, cfg *config.Co KeyId: jwks.KeyId, Audiences: jwks.Audiences, + RefreshUnknownKID: authentication.RefreshUnknownKIDConfig{ + Enabled: jwks.RefreshUnknownKID.Enabled, + Interval: jwks.RefreshUnknownKID.Interval, + Burst: jwks.RefreshUnknownKID.Burst, + }, }) } diff --git a/router/pkg/authentication/jwks_token_decoder.go b/router/pkg/authentication/jwks_token_decoder.go index cae860409f..0cc06ce962 100644 --- a/router/pkg/authentication/jwks_token_decoder.go +++ b/router/pkg/authentication/jwks_token_decoder.go @@ -7,6 +7,8 @@ import ( "net/http" "time" + "golang.org/x/time/rate" + "github.com/MicahParks/jwkset" "github.com/MicahParks/keyfunc/v3" "github.com/golang-jwt/jwt/v5" @@ -47,6 +49,14 @@ type JWKSConfig struct { KeyId string Audiences []string + + RefreshUnknownKID RefreshUnknownKIDConfig +} + +type RefreshUnknownKIDConfig struct { + Enabled bool + Interval time.Duration + Burst int } type audKey struct { @@ -89,11 +99,18 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS audiencesMap[key] = getAudienceSet(c.Audiences) + // Configure the rate limiter for refreshing unknown KIDs + var refreshLimiter *rate.Limiter + if c.RefreshUnknownKID.Enabled { + refreshLimiter = rate.NewLimiter(rate.Every(c.RefreshUnknownKID.Interval), c.RefreshUnknownKID.Burst) + } + jwksetHTTPClientOptions := jwkset.HTTPClientOptions{ HTTPURLs: map[string]jwkset.Storage{ c.URL: store, }, - PrioritizeHTTP: true, + PrioritizeHTTP: true, + RefreshUnknownKID: refreshLimiter, } jwks, err := createKeyFunc(ctx, jwksetHTTPClientOptions) diff --git a/router/pkg/config/config.go b/router/pkg/config/config.go index c33310d657..9a87357953 100644 --- a/router/pkg/config/config.go +++ b/router/pkg/config/config.go @@ -466,9 +466,10 @@ type OverridesConfiguration struct { } type JWKSConfiguration struct { - URL string `yaml:"url"` - Algorithms []string `yaml:"algorithms"` - RefreshInterval time.Duration `yaml:"refresh_interval" envDefault:"1m"` + URL string `yaml:"url"` + Algorithms []string `yaml:"algorithms"` + RefreshInterval time.Duration `yaml:"refresh_interval" envDefault:"1m"` + RefreshUnknownKID RefreshUnknownKIDConfig `yaml:",inline"` // For secret based where we need to create a jwk entry with // a key id and algorithm @@ -480,6 +481,12 @@ type JWKSConfiguration struct { Audiences []string `yaml:"audiences"` } +type RefreshUnknownKIDConfig struct { + Enabled bool `yaml:"enabled" envDefault:"false"` + Interval time.Duration `yaml:"interval" envDefault:"1m"` + Burst int `yaml:"burst" envDefault:"2"` +} + type HeaderSource struct { Type string `yaml:"type"` Name string `yaml:"name"` diff --git a/router/pkg/config/config.schema.json b/router/pkg/config/config.schema.json index c45890172e..0bdd4b323c 100644 --- a/router/pkg/config/config.schema.json +++ b/router/pkg/config/config.schema.json @@ -1716,6 +1716,32 @@ }, "description": "The interval at which the JWKs are refreshed. The period is specified as a string with a number and a unit, e.g. 10ms, 1s, 1m, 1h. The supported units are 'ms', 's', 'm', 'h'.", "default": "1m" + }, + "refresh_unknown_kid": { + "type": "object", + "description": "Controls rate-limited refresh behavior when a JWT KID is unknown.", + "additionalProperties": false, + "properties": { + "enabled": { + "type": "boolean", + "description": "Enable refresh attempts on unknown KID.", + "default": false + }, + "interval": { + "type": "string", + "description": "Token refill interval for the rate limiter.", + "default": "1m", + "duration": { + "minimum": "1s" + } + }, + "burst": { + "type": "integer", + "description": "Burst size for the rate limiter.", + "default": 2, + "minimum": 1 + } + } } }, "oneOf": [ diff --git a/router/pkg/config/testdata/config_full.json b/router/pkg/config/testdata/config_full.json index ec8b47da50..ddff7fb15d 100644 --- a/router/pkg/config/testdata/config_full.json +++ b/router/pkg/config/testdata/config_full.json @@ -480,6 +480,11 @@ "RS256" ], "RefreshInterval": 60000000000, + "RefreshUnknownKID": { + "Enabled": false, + "Interval": 0, + "Burst": 0 + }, "Secret": "", "Algorithm": "", "KeyId": "", @@ -492,6 +497,11 @@ "ES256" ], "RefreshInterval": 120000000000, + "RefreshUnknownKID": { + "Enabled": false, + "Interval": 0, + "Burst": 0 + }, "Secret": "", "Algorithm": "", "KeyId": "", @@ -501,6 +511,11 @@ "URL": "https://example.com/.well-known/jwks3.json", "Algorithms": null, "RefreshInterval": 0, + "RefreshUnknownKID": { + "Enabled": false, + "Interval": 0, + "Burst": 0 + }, "Secret": "", "Algorithm": "", "KeyId": "", From 74ac0c5127d3df06d55868944459809fadc9bcec Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Mon, 22 Sep 2025 02:12:20 +0530 Subject: [PATCH 09/11] fix: changes --- router-tests/authentication_test.go | 109 +++++++++++++++++- router/core/supervisor_instance.go | 1 + router/go.mod | 2 +- .../pkg/authentication/jwks_token_decoder.go | 16 +-- router/pkg/config/config.go | 11 +- router/pkg/config/config.schema.json | 8 ++ router/pkg/config/fixtures/full.yaml | 5 + router/pkg/config/testdata/config_full.json | 9 +- 8 files changed, 143 insertions(+), 18 deletions(-) diff --git a/router-tests/authentication_test.go b/router-tests/authentication_test.go index ab8650c381..0d7fbcb37f 100644 --- a/router-tests/authentication_test.go +++ b/router-tests/authentication_test.go @@ -98,7 +98,7 @@ func TestAuthentication(t *testing.T) { defer func() { _ = res2.Body.Close() }() elapsed := time.Since(start) - require.True(t, elapsed >= 600*time.Millisecond) + require.True(t, elapsed >= 700*time.Millisecond) require.Equal(t, http.StatusUnauthorized, res2.StatusCode) data, err := io.ReadAll(res2.Body) require.NoError(t, err) @@ -157,6 +157,113 @@ func TestAuthentication(t *testing.T) { }) }) + // Since the rate limiter knows that the limit will definitely be exceeded it exits + // immediately without waiting + t.Run("unknown kid refresh interval exceeding max wait returns immediately", func(t *testing.T) { + t.Parallel() + + authServer, err := jwks.NewServer(t) + require.NoError(t, err) + t.Cleanup(authServer.Close) + + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + URL: authServer.JWKSURL(), + RefreshInterval: 10 * time.Second, + RefreshUnknownKID: authentication.RefreshUnknownKIDConfig{ + Enabled: true, + Interval: 1 * time.Second, // next token available in ~1s + Burst: 1, + MaxWait: 700 * time.Millisecond, // cap wait well below interval + }, + }, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + token, err := authServer.TokenForKID("unknown_kid", nil, true) + require.NoError(t, err) + + header := http.Header{"Authorization": []string{"Bearer " + token}} + + // Consume burst token + res1, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res1.Body.Close() }() + require.Equal(t, http.StatusUnauthorized, res1.StatusCode) + _, err = io.ReadAll(res1.Body) + require.NoError(t, err) + + // Next call should block but only up to MaxWait (< interval) + start := time.Now() + res2, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res2.Body.Close() }() + elapsed := time.Since(start) + require.True(t, elapsed < 100*time.Millisecond) + require.Equal(t, http.StatusUnauthorized, res2.StatusCode) + data, err := io.ReadAll(res2.Body) + require.NoError(t, err) + require.JSONEq(t, unauthorizedExpectedData, string(data)) + }) + }) + + t.Run("unknown kid refresh exceeding burst waits until interval when MaxWait larger", func(t *testing.T) { + t.Parallel() + + authServer, err := jwks.NewServer(t) + require.NoError(t, err) + t.Cleanup(authServer.Close) + + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + URL: authServer.JWKSURL(), + RefreshInterval: 10 * time.Second, + RefreshUnknownKID: authentication.RefreshUnknownKIDConfig{ + Enabled: true, + Interval: 1 * time.Second, + Burst: 1, + MaxWait: 2 * time.Second, // larger than interval, so it can wait until next token + }, + }, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + token, err := authServer.TokenForKID("unknown_kid", nil, true) + require.NoError(t, err) + + header := http.Header{"Authorization": []string{"Bearer " + token}} + + // Consume burst token + res1, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res1.Body.Close() }() + require.Equal(t, http.StatusUnauthorized, res1.StatusCode) + _, err = io.ReadAll(res1.Body) + require.NoError(t, err) + + // Next call should wait until limiter allows next refresh (~1s) + start := time.Now() + res2, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res2.Body.Close() }() + elapsed := time.Since(start) + + require.True(t, elapsed >= 600*time.Millisecond) + require.Equal(t, http.StatusUnauthorized, res2.StatusCode) + data, err := io.ReadAll(res2.Body) + require.NoError(t, err) + require.JSONEq(t, unauthorizedExpectedData, string(data)) + }) + }) + t.Run("authentication should not block with unknown kid when refresh is disabled", func(t *testing.T) { t.Parallel() diff --git a/router/core/supervisor_instance.go b/router/core/supervisor_instance.go index b2ca6050c1..7a43d3f138 100644 --- a/router/core/supervisor_instance.go +++ b/router/core/supervisor_instance.go @@ -267,6 +267,7 @@ func setupAuthenticators(ctx context.Context, logger *zap.Logger, cfg *config.Co Audiences: jwks.Audiences, RefreshUnknownKID: authentication.RefreshUnknownKIDConfig{ Enabled: jwks.RefreshUnknownKID.Enabled, + MaxWait: jwks.RefreshUnknownKID.MaxWait, Interval: jwks.RefreshUnknownKID.Interval, Burst: jwks.RefreshUnknownKID.Burst, }, diff --git a/router/go.mod b/router/go.mod index dc8761f12e..ae51779a4f 100644 --- a/router/go.mod +++ b/router/go.mod @@ -83,6 +83,7 @@ require ( go.uber.org/ratelimit v0.3.1 golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 golang.org/x/text v0.23.0 + golang.org/x/time v0.9.0 ) require ( @@ -165,7 +166,6 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/crypto v0.36.0 // indirect golang.org/x/net v0.38.0 // indirect - golang.org/x/time v0.9.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f // indirect gopkg.in/cenkalti/backoff.v1 v1.1.0 // indirect diff --git a/router/pkg/authentication/jwks_token_decoder.go b/router/pkg/authentication/jwks_token_decoder.go index 0cc06ce962..c6ae6c794e 100644 --- a/router/pkg/authentication/jwks_token_decoder.go +++ b/router/pkg/authentication/jwks_token_decoder.go @@ -57,6 +57,7 @@ type RefreshUnknownKIDConfig struct { Enabled bool Interval time.Duration Burst int + MaxWait time.Duration } type audKey struct { @@ -99,18 +100,17 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS audiencesMap[key] = getAudienceSet(c.Audiences) - // Configure the rate limiter for refreshing unknown KIDs - var refreshLimiter *rate.Limiter - if c.RefreshUnknownKID.Enabled { - refreshLimiter = rate.NewLimiter(rate.Every(c.RefreshUnknownKID.Interval), c.RefreshUnknownKID.Burst) - } - jwksetHTTPClientOptions := jwkset.HTTPClientOptions{ HTTPURLs: map[string]jwkset.Storage{ c.URL: store, }, - PrioritizeHTTP: true, - RefreshUnknownKID: refreshLimiter, + PrioritizeHTTP: true, + } + + // Configure the rate limiter for refreshing unknown KIDs + if c.RefreshUnknownKID.Enabled { + jwksetHTTPClientOptions.RefreshUnknownKID = rate.NewLimiter(rate.Every(c.RefreshUnknownKID.Interval), c.RefreshUnknownKID.Burst) + jwksetHTTPClientOptions.RateLimitWaitMax = c.RefreshUnknownKID.MaxWait } jwks, err := createKeyFunc(ctx, jwksetHTTPClientOptions) diff --git a/router/pkg/config/config.go b/router/pkg/config/config.go index 9a87357953..8f410e6700 100644 --- a/router/pkg/config/config.go +++ b/router/pkg/config/config.go @@ -466,10 +466,10 @@ type OverridesConfiguration struct { } type JWKSConfiguration struct { - URL string `yaml:"url"` - Algorithms []string `yaml:"algorithms"` - RefreshInterval time.Duration `yaml:"refresh_interval" envDefault:"1m"` - RefreshUnknownKID RefreshUnknownKIDConfig `yaml:",inline"` + URL string `yaml:"url"` + Algorithms []string `yaml:"algorithms"` + RefreshInterval time.Duration `yaml:"refresh_interval" envDefault:"1m"` + RefreshUnknownKID RefreshUnknownKID `yaml:"refresh_unknown_kid"` // For secret based where we need to create a jwk entry with // a key id and algorithm @@ -481,8 +481,9 @@ type JWKSConfiguration struct { Audiences []string `yaml:"audiences"` } -type RefreshUnknownKIDConfig struct { +type RefreshUnknownKID struct { Enabled bool `yaml:"enabled" envDefault:"false"` + MaxWait time.Duration `yaml:"max_wait" envDefault:"10s"` Interval time.Duration `yaml:"interval" envDefault:"1m"` Burst int `yaml:"burst" envDefault:"2"` } diff --git a/router/pkg/config/config.schema.json b/router/pkg/config/config.schema.json index 0bdd4b323c..570d5f8086 100644 --- a/router/pkg/config/config.schema.json +++ b/router/pkg/config/config.schema.json @@ -1727,6 +1727,14 @@ "description": "Enable refresh attempts on unknown KID.", "default": false }, + "max_wait": { + "type": "string", + "description": "Maximum time to wait for a refresh permit before giving up.", + "default": "10s", + "duration": { + "minimum": "0s" + } + }, "interval": { "type": "string", "description": "Token refill interval for the rate limiter.", diff --git a/router/pkg/config/fixtures/full.yaml b/router/pkg/config/fixtures/full.yaml index afd160f911..a43691cc12 100644 --- a/router/pkg/config/fixtures/full.yaml +++ b/router/pkg/config/fixtures/full.yaml @@ -275,6 +275,11 @@ authentication: - url: 'https://example.com/.well-known/jwks2.json' refresh_interval: 2m algorithms: ['RS256', 'ES256'] + refresh_unknown_kid: + enabled: true + max_wait: 10s + interval: 5s + burst: 3 - url: 'https://example.com/.well-known/jwks3.json' header_name: Authorization header_value_prefix: Bearer diff --git a/router/pkg/config/testdata/config_full.json b/router/pkg/config/testdata/config_full.json index ddff7fb15d..ba295c3935 100644 --- a/router/pkg/config/testdata/config_full.json +++ b/router/pkg/config/testdata/config_full.json @@ -482,6 +482,7 @@ "RefreshInterval": 60000000000, "RefreshUnknownKID": { "Enabled": false, + "MaxWait": 0, "Interval": 0, "Burst": 0 }, @@ -498,9 +499,10 @@ ], "RefreshInterval": 120000000000, "RefreshUnknownKID": { - "Enabled": false, - "Interval": 0, - "Burst": 0 + "Enabled": true, + "MaxWait": 10000000000, + "Interval": 5000000000, + "Burst": 3 }, "Secret": "", "Algorithm": "", @@ -513,6 +515,7 @@ "RefreshInterval": 0, "RefreshUnknownKID": { "Enabled": false, + "MaxWait": 0, "Interval": 0, "Burst": 0 }, From fd38a250e6e7dd9f1753d56f2773d1173cc72ead Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Mon, 22 Sep 2025 02:43:45 +0530 Subject: [PATCH 10/11] fix: tests --- router-tests/authentication_test.go | 80 +++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/router-tests/authentication_test.go b/router-tests/authentication_test.go index 0d7fbcb37f..a7689892a3 100644 --- a/router-tests/authentication_test.go +++ b/router-tests/authentication_test.go @@ -8,6 +8,8 @@ import ( "io" "net/http" "strings" + "sync" + "sync/atomic" "testing" "time" @@ -211,6 +213,84 @@ func TestAuthentication(t *testing.T) { }) }) + // After consuming the single burst token, launch multiple requests in parallel. + // Each should block if the max limit has not been accumulated + t.Run("unknown kid refresh parallel exceeding burst waits up to MaxWait", func(t *testing.T) { + t.Parallel() + + authServer, err := jwks.NewServer(t) + require.NoError(t, err) + t.Cleanup(authServer.Close) + + const waitEntries = 4 + + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + URL: authServer.JWKSURL(), + RefreshInterval: 10 * time.Second, + RefreshUnknownKID: authentication.RefreshUnknownKIDConfig{ + Enabled: true, + Interval: 1 * time.Second, + Burst: 1, + MaxWait: waitEntries * time.Second, + }, + }, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + token, err := authServer.TokenForKID("unknown_kid", nil, true) + require.NoError(t, err) + + header := http.Header{"Authorization": []string{"Bearer " + token}} + + // Send initial request to use up the burst token + res1, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res1.Body.Close() }() + require.Equal(t, http.StatusUnauthorized, res1.StatusCode) + _, err = io.ReadAll(res1.Body) + require.NoError(t, err) + + var elapsedFastCounter atomic.Int64 + var wg sync.WaitGroup + + for range waitEntries + 1 { + wg.Add(1) + + go func() { + defer wg.Done() + + start := time.Now() + res2, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res2.Body.Close() }() + + elapsed := time.Since(start) + + if elapsed < 100*time.Millisecond { + elapsedFastCounter.Add(1) + } + + require.True(t, elapsed < 50*time.Millisecond || elapsed >= 700*time.Millisecond) + require.Equal(t, http.StatusUnauthorized, res2.StatusCode) + data, err := io.ReadAll(res2.Body) + require.NoError(t, err) + require.JSONEq(t, unauthorizedExpectedData, string(data)) + }() + } + + wg.Wait() + + // We only exit early on the 5th request as by the 5th request we have accumulated + // enough tokens to exceed the max wait duration + require.Equal(t, 1, int(elapsedFastCounter.Load())) + }) + }) + t.Run("unknown kid refresh exceeding burst waits until interval when MaxWait larger", func(t *testing.T) { t.Parallel() From 24025e4758663268065119e854e6d747ff4464e3 Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Mon, 22 Sep 2025 14:14:07 +0530 Subject: [PATCH 11/11] fix: default values and the comments --- router-tests/authentication_test.go | 107 ++++++++++++++-------------- router/pkg/config/config.go | 4 +- 2 files changed, 54 insertions(+), 57 deletions(-) diff --git a/router-tests/authentication_test.go b/router-tests/authentication_test.go index a7689892a3..326b84749a 100644 --- a/router-tests/authentication_test.go +++ b/router-tests/authentication_test.go @@ -191,7 +191,6 @@ func TestAuthentication(t *testing.T) { header := http.Header{"Authorization": []string{"Bearer " + token}} - // Consume burst token res1, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) require.NoError(t, err) defer func() { _ = res1.Body.Close() }() @@ -199,7 +198,7 @@ func TestAuthentication(t *testing.T) { _, err = io.ReadAll(res1.Body) require.NoError(t, err) - // Next call should block but only up to MaxWait (< interval) + // Next call should exceed max wait so should return immediately start := time.Now() res2, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) require.NoError(t, err) @@ -213,17 +212,13 @@ func TestAuthentication(t *testing.T) { }) }) - // After consuming the single burst token, launch multiple requests in parallel. - // Each should block if the max limit has not been accumulated - t.Run("unknown kid refresh parallel exceeding burst waits up to MaxWait", func(t *testing.T) { + t.Run("unknown kid refresh exceeding burst waits until interval when max wait larger", func(t *testing.T) { t.Parallel() authServer, err := jwks.NewServer(t) require.NoError(t, err) t.Cleanup(authServer.Close) - const waitEntries = 4 - authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ { URL: authServer.JWKSURL(), @@ -232,7 +227,7 @@ func TestAuthentication(t *testing.T) { Enabled: true, Interval: 1 * time.Second, Burst: 1, - MaxWait: waitEntries * time.Second, + MaxWait: 2 * time.Second, // larger than interval, so it can wait until next token }, }, }) @@ -247,7 +242,6 @@ func TestAuthentication(t *testing.T) { header := http.Header{"Authorization": []string{"Bearer " + token}} - // Send initial request to use up the burst token res1, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) require.NoError(t, err) defer func() { _ = res1.Body.Close() }() @@ -255,49 +249,31 @@ func TestAuthentication(t *testing.T) { _, err = io.ReadAll(res1.Body) require.NoError(t, err) - var elapsedFastCounter atomic.Int64 - var wg sync.WaitGroup - - for range waitEntries + 1 { - wg.Add(1) - - go func() { - defer wg.Done() - - start := time.Now() - res2, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) - require.NoError(t, err) - defer func() { _ = res2.Body.Close() }() - - elapsed := time.Since(start) - - if elapsed < 100*time.Millisecond { - elapsedFastCounter.Add(1) - } - - require.True(t, elapsed < 50*time.Millisecond || elapsed >= 700*time.Millisecond) - require.Equal(t, http.StatusUnauthorized, res2.StatusCode) - data, err := io.ReadAll(res2.Body) - require.NoError(t, err) - require.JSONEq(t, unauthorizedExpectedData, string(data)) - }() - } - - wg.Wait() + start := time.Now() + res2, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res2.Body.Close() }() + elapsed := time.Since(start) - // We only exit early on the 5th request as by the 5th request we have accumulated - // enough tokens to exceed the max wait duration - require.Equal(t, 1, int(elapsedFastCounter.Load())) + require.True(t, elapsed >= 600*time.Millisecond) + require.Equal(t, http.StatusUnauthorized, res2.StatusCode) + data, err := io.ReadAll(res2.Body) + require.NoError(t, err) + require.JSONEq(t, unauthorizedExpectedData, string(data)) }) }) - t.Run("unknown kid refresh exceeding burst waits until interval when MaxWait larger", func(t *testing.T) { + // After consuming the single burst token, launch multiple requests in parallel. + // Each should block if the max limit has not been accumulated + t.Run("unknown kid refresh parallel exceeding burst waits up to max wait", func(t *testing.T) { t.Parallel() authServer, err := jwks.NewServer(t) require.NoError(t, err) t.Cleanup(authServer.Close) + const waitEntries = 4 + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ { URL: authServer.JWKSURL(), @@ -306,7 +282,7 @@ func TestAuthentication(t *testing.T) { Enabled: true, Interval: 1 * time.Second, Burst: 1, - MaxWait: 2 * time.Second, // larger than interval, so it can wait until next token + MaxWait: waitEntries * time.Second, }, }, }) @@ -321,7 +297,7 @@ func TestAuthentication(t *testing.T) { header := http.Header{"Authorization": []string{"Bearer " + token}} - // Consume burst token + // Send initial request to use up the burst token res1, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) require.NoError(t, err) defer func() { _ = res1.Body.Close() }() @@ -329,18 +305,39 @@ func TestAuthentication(t *testing.T) { _, err = io.ReadAll(res1.Body) require.NoError(t, err) - // Next call should wait until limiter allows next refresh (~1s) - start := time.Now() - res2, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) - require.NoError(t, err) - defer func() { _ = res2.Body.Close() }() - elapsed := time.Since(start) + var elapsedFastCounter atomic.Int64 + var wg sync.WaitGroup - require.True(t, elapsed >= 600*time.Millisecond) - require.Equal(t, http.StatusUnauthorized, res2.StatusCode) - data, err := io.ReadAll(res2.Body) - require.NoError(t, err) - require.JSONEq(t, unauthorizedExpectedData, string(data)) + for range waitEntries + 1 { + wg.Add(1) + + go func() { + defer wg.Done() + + start := time.Now() + res2, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res2.Body.Close() }() + + elapsed := time.Since(start) + + if elapsed < 100*time.Millisecond { + elapsedFastCounter.Add(1) + } + + require.True(t, elapsed < 50*time.Millisecond || elapsed >= 700*time.Millisecond) + require.Equal(t, http.StatusUnauthorized, res2.StatusCode) + data, err := io.ReadAll(res2.Body) + require.NoError(t, err) + require.JSONEq(t, unauthorizedExpectedData, string(data)) + }() + } + + wg.Wait() + + // We only exit early on the 5th request as by the 5th request we have accumulated + // enough tokens to exceed the max wait duration + require.Equal(t, 1, int(elapsedFastCounter.Load())) }) }) diff --git a/router/pkg/config/config.go b/router/pkg/config/config.go index 8f410e6700..57591de31e 100644 --- a/router/pkg/config/config.go +++ b/router/pkg/config/config.go @@ -483,8 +483,8 @@ type JWKSConfiguration struct { type RefreshUnknownKID struct { Enabled bool `yaml:"enabled" envDefault:"false"` - MaxWait time.Duration `yaml:"max_wait" envDefault:"10s"` - Interval time.Duration `yaml:"interval" envDefault:"1m"` + MaxWait time.Duration `yaml:"max_wait" envDefault:"2m"` + Interval time.Duration `yaml:"interval" envDefault:"30s"` Burst int `yaml:"burst" envDefault:"2"` }