From ba7534f274c4231d5e79687845ae97887ba81179 Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Mon, 15 Sep 2025 16:12:28 +0530 Subject: [PATCH 01/45] fix: jwt validation blocks on multiple requests --- router-tests/authentication_test.go | 45 +++++++++++++++++-- router-tests/header_set_test.go | 2 +- router-tests/jwks/jwks.go | 12 ++++- router-tests/testenv/testenv.go | 21 +++++++++ .../pkg/authentication/jwks_token_decoder.go | 7 +-- 5 files changed, 75 insertions(+), 12 deletions(-) diff --git a/router-tests/authentication_test.go b/router-tests/authentication_test.go index 38d2715a56..561e36a8a9 100644 --- a/router-tests/authentication_test.go +++ b/router-tests/authentication_test.go @@ -828,6 +828,43 @@ func TestHttpJwksAuthorization(t *testing.T) { }) }) + t.Run("authentication should not block with an invalid token on multiple calls", func(t *testing.T) { + t.Parallel() + + authenticators, authServer := ConfigureAuth(t) + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // This token has a static keyid + token, err := authServer.TokenForKID("wg_static_kid", nil, true) + maxDuration := 5 * time.Second + + xEnv.WaitForTest(maxDuration, func() { + require.NoError(t, err) + for range 5 { + func() { + // Operations with an invalid token should fail + header := http.Header{ + "Authorization": []string{"Bearer " + token}, + } + res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { + _ = res.Body.Close() + }() + require.Equal(t, http.StatusUnauthorized, res.StatusCode) + require.Equal(t, "", res.Header.Get(xAuthenticatedByHeader)) + data, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.JSONEq(t, unauthorizedExpectedData, string(data)) + }() + } + }) + }) + }) + } func TestNonHttpAuthorization(t *testing.T) { @@ -1182,7 +1219,7 @@ func TestAlgorithmMismatch(t *testing.T) { authenticators := []authentication.Authenticator{authenticator} - token, err := authServer.TokenForKID(crypto.KID(), nil) + token, err := authServer.TokenForKID(crypto.KID(), nil, false) require.NoError(t, err) return token, authenticators @@ -1307,7 +1344,7 @@ func TestOidcDiscovery(t *testing.T) { tokens := make(map[string]string) for _, c := range crypto { - token, err := authServer.TokenForKID(c.KID(), nil) + token, err := authServer.TokenForKID(c.KID(), nil, false) require.NoError(t, err) tokens[c.KID()] = token @@ -1421,7 +1458,7 @@ func TestMultipleKeys(t *testing.T) { tokens := make(map[string]string) for _, c := range crypto { - token, err := authServer.TokenForKID(c.KID(), nil) + token, err := authServer.TokenForKID(c.KID(), nil, false) require.NoError(t, err) tokens[c.KID()] = token @@ -1604,7 +1641,7 @@ func TestSupportedAlgorithms(t *testing.T) { authenticators := []authentication.Authenticator{authenticator} - token, err := authServer.TokenForKID(crypto.KID(), nil) + token, err := authServer.TokenForKID(crypto.KID(), nil, false) require.NoError(t, err) return token, authenticators diff --git a/router-tests/header_set_test.go b/router-tests/header_set_test.go index 763663e2b5..a8f15d69af 100644 --- a/router-tests/header_set_test.go +++ b/router-tests/header_set_test.go @@ -275,7 +275,7 @@ func TestHeaderSetWithExpression(t *testing.T) { authenticator, err := authentication.NewHttpHeaderAuthenticator(authOptions) require.NoError(t, err) - token, err := authServer.TokenForKID(rsa1.KID(), map[string]any{"user_id": "TestId"}) + token, err := authServer.TokenForKID(rsa1.KID(), map[string]any{"user_id": "TestId"}, false) require.NoError(t, err) testenv.Run(t, &testenv.Config{ diff --git a/router-tests/jwks/jwks.go b/router-tests/jwks/jwks.go index 080879aa0c..78a414ab9b 100644 --- a/router-tests/jwks/jwks.go +++ b/router-tests/jwks/jwks.go @@ -50,11 +50,19 @@ func (s *Server) Token(claims map[string]any) (string, error) { return "", jwt.ErrInvalidKey } -func (s *Server) TokenForKID(kid string, claims map[string]any) (string, error) { +func (s *Server) TokenForKID(kid string, claims map[string]any, useInvalidKid bool) (string, error) { provider, ok := s.providers[kid] - if !ok { + if !useInvalidKid && !ok { return "", jwt.ErrInvalidKey + } else if useInvalidKid { + // If we don't care about the kid we don't care about the provider + // we just get the first random provider provided + for _, pr := range s.providers { + provider = pr + break + } } + token := jwt.NewWithClaims(provider.SigningMethod(), jwt.MapClaims(claims)) token.Header[jwkset.HeaderKID] = kid return token.SignedString(provider.PrivateKey()) diff --git a/router-tests/testenv/testenv.go b/router-tests/testenv/testenv.go index bc6cf1df84..30cc7929bf 100644 --- a/router-tests/testenv/testenv.go +++ b/router-tests/testenv/testenv.go @@ -2532,6 +2532,27 @@ func (e *Environment) WaitForConnectionCount(desiredCount uint64, timeout time.D } } +func (e *Environment) WaitForTest(timeout time.Duration, testFunction func()) { + e.t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + done := make(chan struct{}) + go func() { + defer close(done) + testFunction() + }() + + select { + case <-done: + return + case <-ctx.Done(): + e.t.Fatalf("test timed out, want %d", timeout) + return + } +} + type EngineStatisticAssertion struct { Subscriptions int64 Connections int64 diff --git a/router/pkg/authentication/jwks_token_decoder.go b/router/pkg/authentication/jwks_token_decoder.go index 53b252632c..cae860409f 100644 --- a/router/pkg/authentication/jwks_token_decoder.go +++ b/router/pkg/authentication/jwks_token_decoder.go @@ -10,10 +10,8 @@ import ( "github.com/MicahParks/jwkset" "github.com/MicahParks/keyfunc/v3" "github.com/golang-jwt/jwt/v5" - "go.uber.org/zap" - "golang.org/x/time/rate" - "github.com/wundergraph/cosmo/router/internal/httpclient" + "go.uber.org/zap" ) type TokenDecoder interface { @@ -95,8 +93,7 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS HTTPURLs: map[string]jwkset.Storage{ c.URL: store, }, - PrioritizeHTTP: true, - RefreshUnknownKID: rate.NewLimiter(rate.Every(5*time.Minute), 1), + PrioritizeHTTP: true, } jwks, err := createKeyFunc(ctx, jwksetHTTPClientOptions) From 79adec8d1b51bd4ab61d4a971cb01a29c2a44510 Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Wed, 17 Sep 2025 17:55:10 +0530 Subject: [PATCH 02/45] fix: review comments --- router-tests/authentication_test.go | 11 ++++++++--- router-tests/jwks/jwks.go | 6 +++--- router/go.mod | 2 +- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/router-tests/authentication_test.go b/router-tests/authentication_test.go index 561e36a8a9..0e6f006285 100644 --- a/router-tests/authentication_test.go +++ b/router-tests/authentication_test.go @@ -839,10 +839,13 @@ func TestHttpJwksAuthorization(t *testing.T) { }, func(t *testing.T, xEnv *testenv.Environment) { // This token has a static keyid token, err := authServer.TokenForKID("wg_static_kid", nil, true) + require.NoError(t, err) + maxDuration := 5 * time.Second - xEnv.WaitForTest(maxDuration, func() { - require.NoError(t, err) + doneCh := make(chan struct{}) + go func() { + defer close(doneCh) for range 5 { func() { // Operations with an invalid token should fail @@ -861,7 +864,9 @@ func TestHttpJwksAuthorization(t *testing.T) { require.JSONEq(t, unauthorizedExpectedData, string(data)) }() } - }) + }() + + testenv.AwaitChannelWithT(t, maxDuration, doneCh, func(t *testing.T, _ struct{}) {}, "test timed out") }) }) diff --git a/router-tests/jwks/jwks.go b/router-tests/jwks/jwks.go index 78a414ab9b..07e0ada272 100644 --- a/router-tests/jwks/jwks.go +++ b/router-tests/jwks/jwks.go @@ -50,11 +50,11 @@ func (s *Server) Token(claims map[string]any) (string, error) { return "", jwt.ErrInvalidKey } -func (s *Server) TokenForKID(kid string, claims map[string]any, useInvalidKid bool) (string, error) { +func (s *Server) TokenForKID(kid string, claims map[string]any, useInvalidKID bool) (string, error) { provider, ok := s.providers[kid] - if !useInvalidKid && !ok { + if !useInvalidKID && !ok { return "", jwt.ErrInvalidKey - } else if useInvalidKid { + } else if useInvalidKID { // If we don't care about the kid we don't care about the provider // we just get the first random provider provided for _, pr := range s.providers { diff --git a/router/go.mod b/router/go.mod index ae51779a4f..dc8761f12e 100644 --- a/router/go.mod +++ b/router/go.mod @@ -83,7 +83,6 @@ require ( go.uber.org/ratelimit v0.3.1 golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 golang.org/x/text v0.23.0 - golang.org/x/time v0.9.0 ) require ( @@ -166,6 +165,7 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/crypto v0.36.0 // indirect golang.org/x/net v0.38.0 // indirect + golang.org/x/time v0.9.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f // indirect gopkg.in/cenkalti/backoff.v1 v1.1.0 // indirect From ca0aaaca2130254aeda27a1aafab0d587c26ada6 Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Wed, 17 Sep 2025 17:59:58 +0530 Subject: [PATCH 03/45] fix: cleanup --- router-tests/testenv/testenv.go | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/router-tests/testenv/testenv.go b/router-tests/testenv/testenv.go index 30cc7929bf..bc6cf1df84 100644 --- a/router-tests/testenv/testenv.go +++ b/router-tests/testenv/testenv.go @@ -2532,27 +2532,6 @@ func (e *Environment) WaitForConnectionCount(desiredCount uint64, timeout time.D } } -func (e *Environment) WaitForTest(timeout time.Duration, testFunction func()) { - e.t.Helper() - - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - - done := make(chan struct{}) - go func() { - defer close(done) - testFunction() - }() - - select { - case <-done: - return - case <-ctx.Done(): - e.t.Fatalf("test timed out, want %d", timeout) - return - } -} - type EngineStatisticAssertion struct { Subscriptions int64 Connections int64 From e0f1e53e2e20843d4cf505d4b947fdbaf4f91957 Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Wed, 17 Sep 2025 18:45:01 +0530 Subject: [PATCH 04/45] fix: require equals --- router-tests/authentication_test.go | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/router-tests/authentication_test.go b/router-tests/authentication_test.go index 0e6f006285..06b2beb49b 100644 --- a/router-tests/authentication_test.go +++ b/router-tests/authentication_test.go @@ -843,9 +843,7 @@ func TestHttpJwksAuthorization(t *testing.T) { maxDuration := 5 * time.Second - doneCh := make(chan struct{}) - go func() { - defer close(doneCh) + require.Eventually(t, func() bool { for range 5 { func() { // Operations with an invalid token should fail @@ -864,9 +862,8 @@ func TestHttpJwksAuthorization(t *testing.T) { require.JSONEq(t, unauthorizedExpectedData, string(data)) }() } - }() - - testenv.AwaitChannelWithT(t, maxDuration, doneCh, func(t *testing.T, _ struct{}) {}, "test timed out") + return true + }, maxDuration, 10*time.Millisecond) }) }) From 3166aa801ca5bd28051e695d207023102418cc6e Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Wed, 17 Sep 2025 18:47:32 +0530 Subject: [PATCH 05/45] Revert "fix: require equals" This reverts commit e0f1e53e2e20843d4cf505d4b947fdbaf4f91957. --- router-tests/authentication_test.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/router-tests/authentication_test.go b/router-tests/authentication_test.go index 06b2beb49b..0e6f006285 100644 --- a/router-tests/authentication_test.go +++ b/router-tests/authentication_test.go @@ -843,7 +843,9 @@ func TestHttpJwksAuthorization(t *testing.T) { maxDuration := 5 * time.Second - require.Eventually(t, func() bool { + doneCh := make(chan struct{}) + go func() { + defer close(doneCh) for range 5 { func() { // Operations with an invalid token should fail @@ -862,8 +864,9 @@ func TestHttpJwksAuthorization(t *testing.T) { require.JSONEq(t, unauthorizedExpectedData, string(data)) }() } - return true - }, maxDuration, 10*time.Millisecond) + }() + + testenv.AwaitChannelWithT(t, maxDuration, doneCh, func(t *testing.T, _ struct{}) {}, "test timed out") }) }) From 3e9ff2aaf4a8a82a5a8b433d8ee1c3b61907f221 Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Wed, 17 Sep 2025 19:01:30 +0530 Subject: [PATCH 06/45] fix: tests --- router-tests/authentication_test.go | 8 ++------ router-tests/testenv/utils.go | 12 ++++++++++++ 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/router-tests/authentication_test.go b/router-tests/authentication_test.go index 0e6f006285..9a41fda9b0 100644 --- a/router-tests/authentication_test.go +++ b/router-tests/authentication_test.go @@ -843,9 +843,7 @@ func TestHttpJwksAuthorization(t *testing.T) { maxDuration := 5 * time.Second - doneCh := make(chan struct{}) - go func() { - defer close(doneCh) + testenv.AwaitFunc(t, maxDuration, func() { for range 5 { func() { // Operations with an invalid token should fail @@ -864,9 +862,7 @@ func TestHttpJwksAuthorization(t *testing.T) { require.JSONEq(t, unauthorizedExpectedData, string(data)) }() } - }() - - testenv.AwaitChannelWithT(t, maxDuration, doneCh, func(t *testing.T, _ struct{}) {}, "test timed out") + }) }) }) diff --git a/router-tests/testenv/utils.go b/router-tests/testenv/utils.go index bd4f1842db..19d13ba27c 100644 --- a/router-tests/testenv/utils.go +++ b/router-tests/testenv/utils.go @@ -28,3 +28,15 @@ func AwaitChannelWithCloseWithT[A any](t *testing.T, timeout time.Duration, ch < require.Fail(t, "unable to receive message before timeout", msgAndArgs...) } } + +func AwaitFunc(t *testing.T, timeout time.Duration, testFunction func()) { + t.Helper() + + doneCh := make(chan struct{}) + go func() { + defer close(doneCh) + testFunction() + }() + + AwaitChannelWithT(t, timeout, doneCh, func(t *testing.T, _ struct{}) {}, "the test function timed out") +} From 6978e0d52bdff4335ca4709204249a49d06405bf Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Fri, 19 Sep 2025 00:46:31 +0530 Subject: [PATCH 07/45] fix: tests --- router-tests/jwks/jwks.go | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/router-tests/jwks/jwks.go b/router-tests/jwks/jwks.go index 07e0ada272..6b77a76812 100644 --- a/router-tests/jwks/jwks.go +++ b/router-tests/jwks/jwks.go @@ -52,15 +52,14 @@ func (s *Server) Token(claims map[string]any) (string, error) { func (s *Server) TokenForKID(kid string, claims map[string]any, useInvalidKID bool) (string, error) { provider, ok := s.providers[kid] - if !useInvalidKID && !ok { - return "", jwt.ErrInvalidKey - } else if useInvalidKID { - // If we don't care about the kid we don't care about the provider - // we just get the first random provider provided + if useInvalidKID { + // If we don't care about the kid, use any available provider for _, pr := range s.providers { provider = pr break } + } else if !ok { + return "", jwt.ErrInvalidKey } token := jwt.NewWithClaims(provider.SigningMethod(), jwt.MapClaims(claims)) From a04b95323b2371fba41a8eddcdf5ed949b60ca59 Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Mon, 22 Sep 2025 01:48:40 +0530 Subject: [PATCH 08/45] fix: make rate limit values configurable --- router-tests/authentication_test.go | 183 ++++++++++++++---- router/core/supervisor_instance.go | 5 + .../pkg/authentication/jwks_token_decoder.go | 19 +- router/pkg/config/config.go | 13 +- router/pkg/config/config.schema.json | 26 +++ router/pkg/config/testdata/config_full.json | 15 ++ 6 files changed, 219 insertions(+), 42 deletions(-) diff --git a/router-tests/authentication_test.go b/router-tests/authentication_test.go index 9a41fda9b0..ab8650c381 100644 --- a/router-tests/authentication_test.go +++ b/router-tests/authentication_test.go @@ -56,6 +56,151 @@ func TestAuthentication(t *testing.T) { }) }) + t.Run("unknown kid refresh blocks when burst exceeded", func(t *testing.T) { + t.Parallel() + + authServer, err := jwks.NewServer(t) + require.NoError(t, err) + t.Cleanup(authServer.Close) + + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + URL: authServer.JWKSURL(), + RefreshInterval: 10 * time.Second, + RefreshUnknownKID: authentication.RefreshUnknownKIDConfig{ + Enabled: true, + Interval: 1 * time.Second, + Burst: 1, + }, + }, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + token, err := authServer.TokenForKID("unknown_kid", nil, true) + require.NoError(t, err) + + header := http.Header{"Authorization": []string{"Bearer " + token}} + + res1, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res1.Body.Close() }() + require.Equal(t, http.StatusUnauthorized, res1.StatusCode) + _, err = io.ReadAll(res1.Body) + require.NoError(t, err) + + start := time.Now() + res2, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res2.Body.Close() }() + elapsed := time.Since(start) + + require.True(t, elapsed >= 600*time.Millisecond) + require.Equal(t, http.StatusUnauthorized, res2.StatusCode) + data, err := io.ReadAll(res2.Body) + require.NoError(t, err) + require.JSONEq(t, unauthorizedExpectedData, string(data)) + }) + }) + + t.Run("unknown kid refresh does not block when burst not exceeded", func(t *testing.T) { + t.Parallel() + + authServer, err := jwks.NewServer(t) + require.NoError(t, err) + t.Cleanup(authServer.Close) + + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + URL: authServer.JWKSURL(), + RefreshInterval: 10 * time.Second, + RefreshUnknownKID: authentication.RefreshUnknownKIDConfig{ + Enabled: true, + Interval: 1 * time.Second, + Burst: 1, + }, + }, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + token, err := authServer.TokenForKID("unknown_kid", nil, true) + require.NoError(t, err) + header := http.Header{"Authorization": []string{"Bearer " + token}} + + res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res.Body.Close() }() + require.Equal(t, http.StatusUnauthorized, res.StatusCode) + _, err = io.ReadAll(res.Body) + require.NoError(t, err) + + // Wait for interval so next refresh is within burst budget + time.Sleep(1200 * time.Millisecond) + + start := time.Now() + res2, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res2.Body.Close() }() + elapsed := time.Since(start) + require.True(t, elapsed < 100*time.Millisecond) + require.Equal(t, http.StatusUnauthorized, res2.StatusCode) + data, err := io.ReadAll(res2.Body) + require.NoError(t, err) + require.JSONEq(t, unauthorizedExpectedData, string(data)) + }) + }) + + t.Run("authentication should not block with unknown kid when refresh is disabled", func(t *testing.T) { + t.Parallel() + + authServer, err := jwks.NewServer(t) + require.NoError(t, err) + t.Cleanup(authServer.Close) + + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + URL: authServer.JWKSURL(), + RefreshInterval: 100 * time.Millisecond, + }, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Create a token signed with a valid key but with an unknown kid header + token, err := authServer.TokenForKID("unknown_kid", nil, true) + require.NoError(t, err) + + maxDuration := 4 * time.Second + testenv.AwaitFunc(t, maxDuration, func() { + for range 5 { + func() { + header := http.Header{ + "Authorization": []string{"Bearer " + token}, + } + res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res.Body.Close() }() + require.Equal(t, http.StatusUnauthorized, res.StatusCode) + require.Equal(t, "", res.Header.Get(xAuthenticatedByHeader)) + data, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.JSONEq(t, unauthorizedExpectedData, string(data)) + }() + } + }) + }) + }) + t.Run("invalid token", func(t *testing.T) { t.Parallel() @@ -828,44 +973,6 @@ func TestHttpJwksAuthorization(t *testing.T) { }) }) - t.Run("authentication should not block with an invalid token on multiple calls", func(t *testing.T) { - t.Parallel() - - authenticators, authServer := ConfigureAuth(t) - testenv.Run(t, &testenv.Config{ - RouterOptions: []core.Option{ - core.WithAccessController(core.NewAccessController(authenticators, true)), - }, - }, func(t *testing.T, xEnv *testenv.Environment) { - // This token has a static keyid - token, err := authServer.TokenForKID("wg_static_kid", nil, true) - require.NoError(t, err) - - maxDuration := 5 * time.Second - - testenv.AwaitFunc(t, maxDuration, func() { - for range 5 { - func() { - // Operations with an invalid token should fail - header := http.Header{ - "Authorization": []string{"Bearer " + token}, - } - res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) - require.NoError(t, err) - defer func() { - _ = res.Body.Close() - }() - require.Equal(t, http.StatusUnauthorized, res.StatusCode) - require.Equal(t, "", res.Header.Get(xAuthenticatedByHeader)) - data, err := io.ReadAll(res.Body) - require.NoError(t, err) - require.JSONEq(t, unauthorizedExpectedData, string(data)) - }() - } - }) - }) - }) - } func TestNonHttpAuthorization(t *testing.T) { diff --git a/router/core/supervisor_instance.go b/router/core/supervisor_instance.go index 09605f6d9e..b2ca6050c1 100644 --- a/router/core/supervisor_instance.go +++ b/router/core/supervisor_instance.go @@ -265,6 +265,11 @@ func setupAuthenticators(ctx context.Context, logger *zap.Logger, cfg *config.Co KeyId: jwks.KeyId, Audiences: jwks.Audiences, + RefreshUnknownKID: authentication.RefreshUnknownKIDConfig{ + Enabled: jwks.RefreshUnknownKID.Enabled, + Interval: jwks.RefreshUnknownKID.Interval, + Burst: jwks.RefreshUnknownKID.Burst, + }, }) } diff --git a/router/pkg/authentication/jwks_token_decoder.go b/router/pkg/authentication/jwks_token_decoder.go index cae860409f..0cc06ce962 100644 --- a/router/pkg/authentication/jwks_token_decoder.go +++ b/router/pkg/authentication/jwks_token_decoder.go @@ -7,6 +7,8 @@ import ( "net/http" "time" + "golang.org/x/time/rate" + "github.com/MicahParks/jwkset" "github.com/MicahParks/keyfunc/v3" "github.com/golang-jwt/jwt/v5" @@ -47,6 +49,14 @@ type JWKSConfig struct { KeyId string Audiences []string + + RefreshUnknownKID RefreshUnknownKIDConfig +} + +type RefreshUnknownKIDConfig struct { + Enabled bool + Interval time.Duration + Burst int } type audKey struct { @@ -89,11 +99,18 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS audiencesMap[key] = getAudienceSet(c.Audiences) + // Configure the rate limiter for refreshing unknown KIDs + var refreshLimiter *rate.Limiter + if c.RefreshUnknownKID.Enabled { + refreshLimiter = rate.NewLimiter(rate.Every(c.RefreshUnknownKID.Interval), c.RefreshUnknownKID.Burst) + } + jwksetHTTPClientOptions := jwkset.HTTPClientOptions{ HTTPURLs: map[string]jwkset.Storage{ c.URL: store, }, - PrioritizeHTTP: true, + PrioritizeHTTP: true, + RefreshUnknownKID: refreshLimiter, } jwks, err := createKeyFunc(ctx, jwksetHTTPClientOptions) diff --git a/router/pkg/config/config.go b/router/pkg/config/config.go index c33310d657..9a87357953 100644 --- a/router/pkg/config/config.go +++ b/router/pkg/config/config.go @@ -466,9 +466,10 @@ type OverridesConfiguration struct { } type JWKSConfiguration struct { - URL string `yaml:"url"` - Algorithms []string `yaml:"algorithms"` - RefreshInterval time.Duration `yaml:"refresh_interval" envDefault:"1m"` + URL string `yaml:"url"` + Algorithms []string `yaml:"algorithms"` + RefreshInterval time.Duration `yaml:"refresh_interval" envDefault:"1m"` + RefreshUnknownKID RefreshUnknownKIDConfig `yaml:",inline"` // For secret based where we need to create a jwk entry with // a key id and algorithm @@ -480,6 +481,12 @@ type JWKSConfiguration struct { Audiences []string `yaml:"audiences"` } +type RefreshUnknownKIDConfig struct { + Enabled bool `yaml:"enabled" envDefault:"false"` + Interval time.Duration `yaml:"interval" envDefault:"1m"` + Burst int `yaml:"burst" envDefault:"2"` +} + type HeaderSource struct { Type string `yaml:"type"` Name string `yaml:"name"` diff --git a/router/pkg/config/config.schema.json b/router/pkg/config/config.schema.json index c45890172e..0bdd4b323c 100644 --- a/router/pkg/config/config.schema.json +++ b/router/pkg/config/config.schema.json @@ -1716,6 +1716,32 @@ }, "description": "The interval at which the JWKs are refreshed. The period is specified as a string with a number and a unit, e.g. 10ms, 1s, 1m, 1h. The supported units are 'ms', 's', 'm', 'h'.", "default": "1m" + }, + "refresh_unknown_kid": { + "type": "object", + "description": "Controls rate-limited refresh behavior when a JWT KID is unknown.", + "additionalProperties": false, + "properties": { + "enabled": { + "type": "boolean", + "description": "Enable refresh attempts on unknown KID.", + "default": false + }, + "interval": { + "type": "string", + "description": "Token refill interval for the rate limiter.", + "default": "1m", + "duration": { + "minimum": "1s" + } + }, + "burst": { + "type": "integer", + "description": "Burst size for the rate limiter.", + "default": 2, + "minimum": 1 + } + } } }, "oneOf": [ diff --git a/router/pkg/config/testdata/config_full.json b/router/pkg/config/testdata/config_full.json index ec8b47da50..ddff7fb15d 100644 --- a/router/pkg/config/testdata/config_full.json +++ b/router/pkg/config/testdata/config_full.json @@ -480,6 +480,11 @@ "RS256" ], "RefreshInterval": 60000000000, + "RefreshUnknownKID": { + "Enabled": false, + "Interval": 0, + "Burst": 0 + }, "Secret": "", "Algorithm": "", "KeyId": "", @@ -492,6 +497,11 @@ "ES256" ], "RefreshInterval": 120000000000, + "RefreshUnknownKID": { + "Enabled": false, + "Interval": 0, + "Burst": 0 + }, "Secret": "", "Algorithm": "", "KeyId": "", @@ -501,6 +511,11 @@ "URL": "https://example.com/.well-known/jwks3.json", "Algorithms": null, "RefreshInterval": 0, + "RefreshUnknownKID": { + "Enabled": false, + "Interval": 0, + "Burst": 0 + }, "Secret": "", "Algorithm": "", "KeyId": "", From 74ac0c5127d3df06d55868944459809fadc9bcec Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Mon, 22 Sep 2025 02:12:20 +0530 Subject: [PATCH 09/45] fix: changes --- router-tests/authentication_test.go | 109 +++++++++++++++++- router/core/supervisor_instance.go | 1 + router/go.mod | 2 +- .../pkg/authentication/jwks_token_decoder.go | 16 +-- router/pkg/config/config.go | 11 +- router/pkg/config/config.schema.json | 8 ++ router/pkg/config/fixtures/full.yaml | 5 + router/pkg/config/testdata/config_full.json | 9 +- 8 files changed, 143 insertions(+), 18 deletions(-) diff --git a/router-tests/authentication_test.go b/router-tests/authentication_test.go index ab8650c381..0d7fbcb37f 100644 --- a/router-tests/authentication_test.go +++ b/router-tests/authentication_test.go @@ -98,7 +98,7 @@ func TestAuthentication(t *testing.T) { defer func() { _ = res2.Body.Close() }() elapsed := time.Since(start) - require.True(t, elapsed >= 600*time.Millisecond) + require.True(t, elapsed >= 700*time.Millisecond) require.Equal(t, http.StatusUnauthorized, res2.StatusCode) data, err := io.ReadAll(res2.Body) require.NoError(t, err) @@ -157,6 +157,113 @@ func TestAuthentication(t *testing.T) { }) }) + // Since the rate limiter knows that the limit will definitely be exceeded it exits + // immediately without waiting + t.Run("unknown kid refresh interval exceeding max wait returns immediately", func(t *testing.T) { + t.Parallel() + + authServer, err := jwks.NewServer(t) + require.NoError(t, err) + t.Cleanup(authServer.Close) + + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + URL: authServer.JWKSURL(), + RefreshInterval: 10 * time.Second, + RefreshUnknownKID: authentication.RefreshUnknownKIDConfig{ + Enabled: true, + Interval: 1 * time.Second, // next token available in ~1s + Burst: 1, + MaxWait: 700 * time.Millisecond, // cap wait well below interval + }, + }, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + token, err := authServer.TokenForKID("unknown_kid", nil, true) + require.NoError(t, err) + + header := http.Header{"Authorization": []string{"Bearer " + token}} + + // Consume burst token + res1, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res1.Body.Close() }() + require.Equal(t, http.StatusUnauthorized, res1.StatusCode) + _, err = io.ReadAll(res1.Body) + require.NoError(t, err) + + // Next call should block but only up to MaxWait (< interval) + start := time.Now() + res2, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res2.Body.Close() }() + elapsed := time.Since(start) + require.True(t, elapsed < 100*time.Millisecond) + require.Equal(t, http.StatusUnauthorized, res2.StatusCode) + data, err := io.ReadAll(res2.Body) + require.NoError(t, err) + require.JSONEq(t, unauthorizedExpectedData, string(data)) + }) + }) + + t.Run("unknown kid refresh exceeding burst waits until interval when MaxWait larger", func(t *testing.T) { + t.Parallel() + + authServer, err := jwks.NewServer(t) + require.NoError(t, err) + t.Cleanup(authServer.Close) + + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + URL: authServer.JWKSURL(), + RefreshInterval: 10 * time.Second, + RefreshUnknownKID: authentication.RefreshUnknownKIDConfig{ + Enabled: true, + Interval: 1 * time.Second, + Burst: 1, + MaxWait: 2 * time.Second, // larger than interval, so it can wait until next token + }, + }, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + token, err := authServer.TokenForKID("unknown_kid", nil, true) + require.NoError(t, err) + + header := http.Header{"Authorization": []string{"Bearer " + token}} + + // Consume burst token + res1, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res1.Body.Close() }() + require.Equal(t, http.StatusUnauthorized, res1.StatusCode) + _, err = io.ReadAll(res1.Body) + require.NoError(t, err) + + // Next call should wait until limiter allows next refresh (~1s) + start := time.Now() + res2, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res2.Body.Close() }() + elapsed := time.Since(start) + + require.True(t, elapsed >= 600*time.Millisecond) + require.Equal(t, http.StatusUnauthorized, res2.StatusCode) + data, err := io.ReadAll(res2.Body) + require.NoError(t, err) + require.JSONEq(t, unauthorizedExpectedData, string(data)) + }) + }) + t.Run("authentication should not block with unknown kid when refresh is disabled", func(t *testing.T) { t.Parallel() diff --git a/router/core/supervisor_instance.go b/router/core/supervisor_instance.go index b2ca6050c1..7a43d3f138 100644 --- a/router/core/supervisor_instance.go +++ b/router/core/supervisor_instance.go @@ -267,6 +267,7 @@ func setupAuthenticators(ctx context.Context, logger *zap.Logger, cfg *config.Co Audiences: jwks.Audiences, RefreshUnknownKID: authentication.RefreshUnknownKIDConfig{ Enabled: jwks.RefreshUnknownKID.Enabled, + MaxWait: jwks.RefreshUnknownKID.MaxWait, Interval: jwks.RefreshUnknownKID.Interval, Burst: jwks.RefreshUnknownKID.Burst, }, diff --git a/router/go.mod b/router/go.mod index dc8761f12e..ae51779a4f 100644 --- a/router/go.mod +++ b/router/go.mod @@ -83,6 +83,7 @@ require ( go.uber.org/ratelimit v0.3.1 golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 golang.org/x/text v0.23.0 + golang.org/x/time v0.9.0 ) require ( @@ -165,7 +166,6 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/crypto v0.36.0 // indirect golang.org/x/net v0.38.0 // indirect - golang.org/x/time v0.9.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f // indirect gopkg.in/cenkalti/backoff.v1 v1.1.0 // indirect diff --git a/router/pkg/authentication/jwks_token_decoder.go b/router/pkg/authentication/jwks_token_decoder.go index 0cc06ce962..c6ae6c794e 100644 --- a/router/pkg/authentication/jwks_token_decoder.go +++ b/router/pkg/authentication/jwks_token_decoder.go @@ -57,6 +57,7 @@ type RefreshUnknownKIDConfig struct { Enabled bool Interval time.Duration Burst int + MaxWait time.Duration } type audKey struct { @@ -99,18 +100,17 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS audiencesMap[key] = getAudienceSet(c.Audiences) - // Configure the rate limiter for refreshing unknown KIDs - var refreshLimiter *rate.Limiter - if c.RefreshUnknownKID.Enabled { - refreshLimiter = rate.NewLimiter(rate.Every(c.RefreshUnknownKID.Interval), c.RefreshUnknownKID.Burst) - } - jwksetHTTPClientOptions := jwkset.HTTPClientOptions{ HTTPURLs: map[string]jwkset.Storage{ c.URL: store, }, - PrioritizeHTTP: true, - RefreshUnknownKID: refreshLimiter, + PrioritizeHTTP: true, + } + + // Configure the rate limiter for refreshing unknown KIDs + if c.RefreshUnknownKID.Enabled { + jwksetHTTPClientOptions.RefreshUnknownKID = rate.NewLimiter(rate.Every(c.RefreshUnknownKID.Interval), c.RefreshUnknownKID.Burst) + jwksetHTTPClientOptions.RateLimitWaitMax = c.RefreshUnknownKID.MaxWait } jwks, err := createKeyFunc(ctx, jwksetHTTPClientOptions) diff --git a/router/pkg/config/config.go b/router/pkg/config/config.go index 9a87357953..8f410e6700 100644 --- a/router/pkg/config/config.go +++ b/router/pkg/config/config.go @@ -466,10 +466,10 @@ type OverridesConfiguration struct { } type JWKSConfiguration struct { - URL string `yaml:"url"` - Algorithms []string `yaml:"algorithms"` - RefreshInterval time.Duration `yaml:"refresh_interval" envDefault:"1m"` - RefreshUnknownKID RefreshUnknownKIDConfig `yaml:",inline"` + URL string `yaml:"url"` + Algorithms []string `yaml:"algorithms"` + RefreshInterval time.Duration `yaml:"refresh_interval" envDefault:"1m"` + RefreshUnknownKID RefreshUnknownKID `yaml:"refresh_unknown_kid"` // For secret based where we need to create a jwk entry with // a key id and algorithm @@ -481,8 +481,9 @@ type JWKSConfiguration struct { Audiences []string `yaml:"audiences"` } -type RefreshUnknownKIDConfig struct { +type RefreshUnknownKID struct { Enabled bool `yaml:"enabled" envDefault:"false"` + MaxWait time.Duration `yaml:"max_wait" envDefault:"10s"` Interval time.Duration `yaml:"interval" envDefault:"1m"` Burst int `yaml:"burst" envDefault:"2"` } diff --git a/router/pkg/config/config.schema.json b/router/pkg/config/config.schema.json index 0bdd4b323c..570d5f8086 100644 --- a/router/pkg/config/config.schema.json +++ b/router/pkg/config/config.schema.json @@ -1727,6 +1727,14 @@ "description": "Enable refresh attempts on unknown KID.", "default": false }, + "max_wait": { + "type": "string", + "description": "Maximum time to wait for a refresh permit before giving up.", + "default": "10s", + "duration": { + "minimum": "0s" + } + }, "interval": { "type": "string", "description": "Token refill interval for the rate limiter.", diff --git a/router/pkg/config/fixtures/full.yaml b/router/pkg/config/fixtures/full.yaml index afd160f911..a43691cc12 100644 --- a/router/pkg/config/fixtures/full.yaml +++ b/router/pkg/config/fixtures/full.yaml @@ -275,6 +275,11 @@ authentication: - url: 'https://example.com/.well-known/jwks2.json' refresh_interval: 2m algorithms: ['RS256', 'ES256'] + refresh_unknown_kid: + enabled: true + max_wait: 10s + interval: 5s + burst: 3 - url: 'https://example.com/.well-known/jwks3.json' header_name: Authorization header_value_prefix: Bearer diff --git a/router/pkg/config/testdata/config_full.json b/router/pkg/config/testdata/config_full.json index ddff7fb15d..ba295c3935 100644 --- a/router/pkg/config/testdata/config_full.json +++ b/router/pkg/config/testdata/config_full.json @@ -482,6 +482,7 @@ "RefreshInterval": 60000000000, "RefreshUnknownKID": { "Enabled": false, + "MaxWait": 0, "Interval": 0, "Burst": 0 }, @@ -498,9 +499,10 @@ ], "RefreshInterval": 120000000000, "RefreshUnknownKID": { - "Enabled": false, - "Interval": 0, - "Burst": 0 + "Enabled": true, + "MaxWait": 10000000000, + "Interval": 5000000000, + "Burst": 3 }, "Secret": "", "Algorithm": "", @@ -513,6 +515,7 @@ "RefreshInterval": 0, "RefreshUnknownKID": { "Enabled": false, + "MaxWait": 0, "Interval": 0, "Burst": 0 }, From fd38a250e6e7dd9f1753d56f2773d1173cc72ead Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Mon, 22 Sep 2025 02:43:45 +0530 Subject: [PATCH 10/45] fix: tests --- router-tests/authentication_test.go | 80 +++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/router-tests/authentication_test.go b/router-tests/authentication_test.go index 0d7fbcb37f..a7689892a3 100644 --- a/router-tests/authentication_test.go +++ b/router-tests/authentication_test.go @@ -8,6 +8,8 @@ import ( "io" "net/http" "strings" + "sync" + "sync/atomic" "testing" "time" @@ -211,6 +213,84 @@ func TestAuthentication(t *testing.T) { }) }) + // After consuming the single burst token, launch multiple requests in parallel. + // Each should block if the max limit has not been accumulated + t.Run("unknown kid refresh parallel exceeding burst waits up to MaxWait", func(t *testing.T) { + t.Parallel() + + authServer, err := jwks.NewServer(t) + require.NoError(t, err) + t.Cleanup(authServer.Close) + + const waitEntries = 4 + + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + URL: authServer.JWKSURL(), + RefreshInterval: 10 * time.Second, + RefreshUnknownKID: authentication.RefreshUnknownKIDConfig{ + Enabled: true, + Interval: 1 * time.Second, + Burst: 1, + MaxWait: waitEntries * time.Second, + }, + }, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + token, err := authServer.TokenForKID("unknown_kid", nil, true) + require.NoError(t, err) + + header := http.Header{"Authorization": []string{"Bearer " + token}} + + // Send initial request to use up the burst token + res1, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res1.Body.Close() }() + require.Equal(t, http.StatusUnauthorized, res1.StatusCode) + _, err = io.ReadAll(res1.Body) + require.NoError(t, err) + + var elapsedFastCounter atomic.Int64 + var wg sync.WaitGroup + + for range waitEntries + 1 { + wg.Add(1) + + go func() { + defer wg.Done() + + start := time.Now() + res2, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res2.Body.Close() }() + + elapsed := time.Since(start) + + if elapsed < 100*time.Millisecond { + elapsedFastCounter.Add(1) + } + + require.True(t, elapsed < 50*time.Millisecond || elapsed >= 700*time.Millisecond) + require.Equal(t, http.StatusUnauthorized, res2.StatusCode) + data, err := io.ReadAll(res2.Body) + require.NoError(t, err) + require.JSONEq(t, unauthorizedExpectedData, string(data)) + }() + } + + wg.Wait() + + // We only exit early on the 5th request as by the 5th request we have accumulated + // enough tokens to exceed the max wait duration + require.Equal(t, 1, int(elapsedFastCounter.Load())) + }) + }) + t.Run("unknown kid refresh exceeding burst waits until interval when MaxWait larger", func(t *testing.T) { t.Parallel() From 24025e4758663268065119e854e6d747ff4464e3 Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Mon, 22 Sep 2025 14:14:07 +0530 Subject: [PATCH 11/45] fix: default values and the comments --- router-tests/authentication_test.go | 107 ++++++++++++++-------------- router/pkg/config/config.go | 4 +- 2 files changed, 54 insertions(+), 57 deletions(-) diff --git a/router-tests/authentication_test.go b/router-tests/authentication_test.go index a7689892a3..326b84749a 100644 --- a/router-tests/authentication_test.go +++ b/router-tests/authentication_test.go @@ -191,7 +191,6 @@ func TestAuthentication(t *testing.T) { header := http.Header{"Authorization": []string{"Bearer " + token}} - // Consume burst token res1, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) require.NoError(t, err) defer func() { _ = res1.Body.Close() }() @@ -199,7 +198,7 @@ func TestAuthentication(t *testing.T) { _, err = io.ReadAll(res1.Body) require.NoError(t, err) - // Next call should block but only up to MaxWait (< interval) + // Next call should exceed max wait so should return immediately start := time.Now() res2, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) require.NoError(t, err) @@ -213,17 +212,13 @@ func TestAuthentication(t *testing.T) { }) }) - // After consuming the single burst token, launch multiple requests in parallel. - // Each should block if the max limit has not been accumulated - t.Run("unknown kid refresh parallel exceeding burst waits up to MaxWait", func(t *testing.T) { + t.Run("unknown kid refresh exceeding burst waits until interval when max wait larger", func(t *testing.T) { t.Parallel() authServer, err := jwks.NewServer(t) require.NoError(t, err) t.Cleanup(authServer.Close) - const waitEntries = 4 - authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ { URL: authServer.JWKSURL(), @@ -232,7 +227,7 @@ func TestAuthentication(t *testing.T) { Enabled: true, Interval: 1 * time.Second, Burst: 1, - MaxWait: waitEntries * time.Second, + MaxWait: 2 * time.Second, // larger than interval, so it can wait until next token }, }, }) @@ -247,7 +242,6 @@ func TestAuthentication(t *testing.T) { header := http.Header{"Authorization": []string{"Bearer " + token}} - // Send initial request to use up the burst token res1, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) require.NoError(t, err) defer func() { _ = res1.Body.Close() }() @@ -255,49 +249,31 @@ func TestAuthentication(t *testing.T) { _, err = io.ReadAll(res1.Body) require.NoError(t, err) - var elapsedFastCounter atomic.Int64 - var wg sync.WaitGroup - - for range waitEntries + 1 { - wg.Add(1) - - go func() { - defer wg.Done() - - start := time.Now() - res2, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) - require.NoError(t, err) - defer func() { _ = res2.Body.Close() }() - - elapsed := time.Since(start) - - if elapsed < 100*time.Millisecond { - elapsedFastCounter.Add(1) - } - - require.True(t, elapsed < 50*time.Millisecond || elapsed >= 700*time.Millisecond) - require.Equal(t, http.StatusUnauthorized, res2.StatusCode) - data, err := io.ReadAll(res2.Body) - require.NoError(t, err) - require.JSONEq(t, unauthorizedExpectedData, string(data)) - }() - } - - wg.Wait() + start := time.Now() + res2, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res2.Body.Close() }() + elapsed := time.Since(start) - // We only exit early on the 5th request as by the 5th request we have accumulated - // enough tokens to exceed the max wait duration - require.Equal(t, 1, int(elapsedFastCounter.Load())) + require.True(t, elapsed >= 600*time.Millisecond) + require.Equal(t, http.StatusUnauthorized, res2.StatusCode) + data, err := io.ReadAll(res2.Body) + require.NoError(t, err) + require.JSONEq(t, unauthorizedExpectedData, string(data)) }) }) - t.Run("unknown kid refresh exceeding burst waits until interval when MaxWait larger", func(t *testing.T) { + // After consuming the single burst token, launch multiple requests in parallel. + // Each should block if the max limit has not been accumulated + t.Run("unknown kid refresh parallel exceeding burst waits up to max wait", func(t *testing.T) { t.Parallel() authServer, err := jwks.NewServer(t) require.NoError(t, err) t.Cleanup(authServer.Close) + const waitEntries = 4 + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ { URL: authServer.JWKSURL(), @@ -306,7 +282,7 @@ func TestAuthentication(t *testing.T) { Enabled: true, Interval: 1 * time.Second, Burst: 1, - MaxWait: 2 * time.Second, // larger than interval, so it can wait until next token + MaxWait: waitEntries * time.Second, }, }, }) @@ -321,7 +297,7 @@ func TestAuthentication(t *testing.T) { header := http.Header{"Authorization": []string{"Bearer " + token}} - // Consume burst token + // Send initial request to use up the burst token res1, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) require.NoError(t, err) defer func() { _ = res1.Body.Close() }() @@ -329,18 +305,39 @@ func TestAuthentication(t *testing.T) { _, err = io.ReadAll(res1.Body) require.NoError(t, err) - // Next call should wait until limiter allows next refresh (~1s) - start := time.Now() - res2, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) - require.NoError(t, err) - defer func() { _ = res2.Body.Close() }() - elapsed := time.Since(start) + var elapsedFastCounter atomic.Int64 + var wg sync.WaitGroup - require.True(t, elapsed >= 600*time.Millisecond) - require.Equal(t, http.StatusUnauthorized, res2.StatusCode) - data, err := io.ReadAll(res2.Body) - require.NoError(t, err) - require.JSONEq(t, unauthorizedExpectedData, string(data)) + for range waitEntries + 1 { + wg.Add(1) + + go func() { + defer wg.Done() + + start := time.Now() + res2, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res2.Body.Close() }() + + elapsed := time.Since(start) + + if elapsed < 100*time.Millisecond { + elapsedFastCounter.Add(1) + } + + require.True(t, elapsed < 50*time.Millisecond || elapsed >= 700*time.Millisecond) + require.Equal(t, http.StatusUnauthorized, res2.StatusCode) + data, err := io.ReadAll(res2.Body) + require.NoError(t, err) + require.JSONEq(t, unauthorizedExpectedData, string(data)) + }() + } + + wg.Wait() + + // We only exit early on the 5th request as by the 5th request we have accumulated + // enough tokens to exceed the max wait duration + require.Equal(t, 1, int(elapsedFastCounter.Load())) }) }) diff --git a/router/pkg/config/config.go b/router/pkg/config/config.go index 8f410e6700..57591de31e 100644 --- a/router/pkg/config/config.go +++ b/router/pkg/config/config.go @@ -483,8 +483,8 @@ type JWKSConfiguration struct { type RefreshUnknownKID struct { Enabled bool `yaml:"enabled" envDefault:"false"` - MaxWait time.Duration `yaml:"max_wait" envDefault:"10s"` - Interval time.Duration `yaml:"interval" envDefault:"1m"` + MaxWait time.Duration `yaml:"max_wait" envDefault:"2m"` + Interval time.Duration `yaml:"interval" envDefault:"30s"` Burst int `yaml:"burst" envDefault:"2"` } From 2a3c31d6d9d60ae2840c4a2de62e13b4edf1b2cd Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Wed, 17 Sep 2025 16:25:51 +0530 Subject: [PATCH 12/45] feat: allow algorithm be unspecified in jwks --- .../pkg/authentication/jwks_token_decoder.go | 2 +- router/pkg/authentication/validation_store.go | 52 +++-- .../authentication/validation_store_test.go | 178 ++++++++++++++++++ router/pkg/config/config.go | 3 +- 4 files changed, 214 insertions(+), 21 deletions(-) create mode 100644 router/pkg/authentication/validation_store_test.go diff --git a/router/pkg/authentication/jwks_token_decoder.go b/router/pkg/authentication/jwks_token_decoder.go index c6ae6c794e..fb4cbe6de1 100644 --- a/router/pkg/authentication/jwks_token_decoder.go +++ b/router/pkg/authentication/jwks_token_decoder.go @@ -90,7 +90,7 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS l.Error("Failed to refresh HTTP JWK Set from remote HTTP resource.", zap.Error(err)) }, RefreshInterval: c.RefreshInterval, - Storage: NewValidationStore(logger, nil, c.AllowedAlgorithms), + Storage: NewValidationStore(logger, nil, c.AllowedAlgorithms, c.AllowEmptyAlgorithm), } store, err := jwkset.NewStorageFromHTTP(c.URL, jwksetHTTPStorageOptions) diff --git a/router/pkg/authentication/validation_store.go b/router/pkg/authentication/validation_store.go index d447c26710..62698167c8 100644 --- a/router/pkg/authentication/validation_store.go +++ b/router/pkg/authentication/validation_store.go @@ -12,9 +12,10 @@ import ( var _ jwkset.Storage = (*validationStore)(nil) type validationStore struct { - logger *zap.Logger - algs map[string]struct{} - inner jwkset.Storage + logger *zap.Logger + algs map[string]struct{} + inner jwkset.Storage + allowEmptyAlgorithm bool } var supportedAlgorithms = map[string]struct{}{ @@ -33,7 +34,7 @@ var supportedAlgorithms = map[string]struct{}{ "EdDSA": {}, } -func NewValidationStore(logger *zap.Logger, inner jwkset.Storage, algs []string) jwkset.Storage { +func NewValidationStore(logger *zap.Logger, inner jwkset.Storage, algs []string, allowEmptyAlgorithm bool) jwkset.Storage { if inner == nil { inner = jwkset.NewMemoryStorage() } @@ -45,9 +46,10 @@ func NewValidationStore(logger *zap.Logger, inner jwkset.Storage, algs []string) algSet := make(map[string]struct{}, len(algs)) store := &validationStore{ - logger: logger, - inner: inner, - algs: supportedAlgorithms, + logger: logger, + inner: inner, + algs: supportedAlgorithms, + allowEmptyAlgorithm: allowEmptyAlgorithm, } if len(algs) == 0 { @@ -76,12 +78,11 @@ func (v *validationStore) KeyRead(ctx context.Context, keyID string) (jwkset.JWK return key, err } - m := key.Marshal() - if _, ok := v.algs[m.ALG.String()]; ok { - return key, nil + if fKey, ok := v.getFilteredKey(key); ok { + return fKey, nil } - return jwkset.JWK{}, fmt.Errorf("key with ID %q has an unsupported algorithm %s", keyID, m.ALG.String()) + return jwkset.JWK{}, fmt.Errorf("key with ID %q has an unsupported algorithm %s", keyID, key.Marshal().ALG.String()) } func (v *validationStore) KeyReadAll(ctx context.Context) ([]jwkset.JWK, error) { @@ -93,9 +94,8 @@ func (v *validationStore) KeyReadAll(ctx context.Context) ([]jwkset.JWK, error) filter := make([]jwkset.JWK, 0, len(keys)) for _, k := range keys { - m := k.Marshal() - if _, ok := v.algs[m.ALG.String()]; ok { - filter = append(filter, k) + if fKey, ok := v.getFilteredKey(k); ok { + filter = append(filter, fKey) } } @@ -105,20 +105,19 @@ func (v *validationStore) KeyReadAll(ctx context.Context) ([]jwkset.JWK, error) func (v *validationStore) KeyReplaceAll(ctx context.Context, given []jwkset.JWK) error { filtered := make([]jwkset.JWK, 0) for _, k := range given { - m := k.Marshal() - if _, ok := v.algs[m.ALG.String()]; ok { - filtered = append(filtered, k) + if fKey, ok := v.getFilteredKey(k); ok { + filtered = append(filtered, fKey) } } return v.inner.KeyReplaceAll(ctx, filtered) } func (v *validationStore) KeyWrite(ctx context.Context, jwk jwkset.JWK) error { - jwkMarshal := jwk.Marshal() - if _, ok := v.algs[jwkMarshal.ALG.String()]; !ok { + if _, ok := v.getFilteredKey(jwk); !ok { // We should not return an error here. If JWKS are configured for multiple applications, we should only add the // supported keys to the token decoder store and not prevent the refresh entirely. // In case we are receiving a key with an unsupported algorithm we log a warning instead. + jwkMarshal := jwk.Marshal() v.logger.Warn("Skipping key with unsupported algorithm", zap.String("keyID", jwkMarshal.KID), zap.String("algorithm", jwkMarshal.ALG.String())) return nil } @@ -149,3 +148,18 @@ func (v *validationStore) Marshal(ctx context.Context) (jwkset.JWKSMarshal, erro func (v *validationStore) MarshalWithOptions(ctx context.Context, marshalOptions jwkset.JWKMarshalOptions, validationOptions jwkset.JWKValidateOptions) (jwkset.JWKSMarshal, error) { return v.inner.MarshalWithOptions(ctx, marshalOptions, validationOptions) } + +func (v *validationStore) getFilteredKey(k jwkset.JWK) (jwkset.JWK, bool) { + algString := k.Marshal().ALG.String() + + // If we allow empty algorithm, we accept JWK without an algorithm + // This is algorithm is actually optional according to the RFC + if algString == "" && v.allowEmptyAlgorithm { + return k, true + } + if _, ok := v.algs[algString]; ok { + return k, true + } + + return jwkset.JWK{}, false +} diff --git a/router/pkg/authentication/validation_store_test.go b/router/pkg/authentication/validation_store_test.go new file mode 100644 index 0000000000..a3e21f5a06 --- /dev/null +++ b/router/pkg/authentication/validation_store_test.go @@ -0,0 +1,178 @@ +package authentication + +import ( + "context" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "strings" + "testing" + + "github.com/MicahParks/jwkset" + "go.uber.org/zap" +) + +func makeEd25519JWK(t *testing.T, kid string, alg string, setAlg bool) jwkset.JWK { + t.Helper() + _, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("failed to generate ed25519 key: %v", err) + } + + meta := jwkset.JWKMetadataOptions{KID: kid, USE: jwkset.UseSig} + if setAlg { + meta.ALG = jwkset.ALG(alg) + } + options := jwkset.JWKOptions{Metadata: meta} + + j, err := jwkset.NewJWKFromKey(priv, options) + if err != nil { + t.Fatalf("failed to create JWK: %v", err) + } + return j +} + +func makeES256JWK(t *testing.T, kid string) jwkset.JWK { + t.Helper() + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("failed to generate ecdsa p256 key: %v", err) + } + meta := jwkset.JWKMetadataOptions{KID: kid, USE: jwkset.UseSig, ALG: jwkset.AlgES256} + options := jwkset.JWKOptions{Metadata: meta} + j, err := jwkset.NewJWKFromKey(priv, options) + if err != nil { + t.Fatalf("failed to create JWK: %v", err) + } + return j +} + +func TestKeyWriteAndRead_SupportedAlg(t *testing.T) { + ctx := context.Background() + store := NewValidationStore(zap.NewNop(), nil, nil, false) + + good := makeEd25519JWK(t, "good-ed", "EdDSA", true) + if err := store.KeyWrite(ctx, good); err != nil { + t.Fatalf("KeyWrite failed: %v", err) + } + + got, err := store.KeyRead(ctx, "good-ed") + if err != nil { + t.Fatalf("KeyRead failed: %v", err) + } + if got.Marshal().KID != "good-ed" || got.Marshal().ALG.String() != "EdDSA" { + t.Fatalf("unexpected key metadata: kid=%q alg=%q", got.Marshal().KID, got.Marshal().ALG.String()) + } +} + +func TestKeyWrite_SkipsUnsupportedAlg(t *testing.T) { + ctx := context.Background() + store := NewValidationStore(zap.NewNop(), nil, nil, false) + + bad := makeEd25519JWK(t, "bad", "FOO", true) + if err := store.KeyWrite(ctx, bad); err != nil { + t.Fatalf("KeyWrite returned error for unsupported alg (should skip without error): %v", err) + } + + if _, err := store.KeyRead(ctx, "bad"); err == nil { + t.Fatalf("expected KeyRead to fail for skipped unsupported key") + } + + all, err := store.KeyReadAll(ctx) + if err != nil { + t.Fatalf("KeyReadAll failed: %v", err) + } + if len(all) != 0 { + t.Fatalf("expected 0 keys, got %d", len(all)) + } +} + +func TestAllowEmptyAlgorithm(t *testing.T) { + ctx := context.Background() + + // allowEmptyAlgorithm = true accepts keys without ALG + storeAllow := NewValidationStore(zap.NewNop(), nil, nil, true) + noAlg := makeEd25519JWK(t, "noalg", "", false) + if err := storeAllow.KeyWrite(ctx, noAlg); err != nil { + t.Fatalf("KeyWrite failed: %v", err) + } + if _, err := storeAllow.KeyRead(ctx, "noalg"); err != nil { + t.Fatalf("expected KeyRead to succeed for empty ALG when allowed, got: %v", err) + } + + // allowEmptyAlgorithm = false skips keys without ALG + storeDeny := NewValidationStore(zap.NewNop(), nil, nil, false) + if err := storeDeny.KeyWrite(ctx, noAlg); err != nil { + t.Fatalf("KeyWrite returned error for empty ALG (should skip without error): %v", err) + } + if _, err := storeDeny.KeyRead(ctx, "noalg"); err == nil { + t.Fatalf("expected KeyRead to fail for skipped empty-ALG key") + } +} + +func TestRestrictAlgorithmsList(t *testing.T) { + ctx := context.Background() + store := NewValidationStore(zap.NewNop(), nil, []string{"ES256"}, false) + + // EdDSA should be rejected when only ES256 is allowed + ed := makeEd25519JWK(t, "ed", "EdDSA", true) + if err := store.KeyWrite(ctx, ed); err != nil { + t.Fatalf("KeyWrite returned error while skipping disallowed alg: %v", err) + } + if _, err := store.KeyRead(ctx, "ed"); err == nil { + t.Fatalf("expected KeyRead to fail for disallowed alg EdDSA") + } + + // ES256 should be accepted + es := makeES256JWK(t, "es256") + if err := store.KeyWrite(ctx, es); err != nil { + t.Fatalf("KeyWrite failed for ES256: %v", err) + } + if _, err := store.KeyRead(ctx, "es256"); err != nil { + t.Fatalf("expected KeyRead to succeed for ES256, got: %v", err) + } +} + +func TestKeyReplaceAll_Filters(t *testing.T) { + ctx := context.Background() + store := NewValidationStore(zap.NewNop(), nil, nil, false) + + good := makeEd25519JWK(t, "good", "EdDSA", true) + bad := makeEd25519JWK(t, "bad", "FOO", true) + noAlg := makeEd25519JWK(t, "noalg2", "", false) + + if err := store.KeyReplaceAll(ctx, []jwkset.JWK{good, bad, noAlg}); err != nil { + t.Fatalf("KeyReplaceAll failed: %v", err) + } + keys, err := store.KeyReadAll(ctx) + if err != nil { + t.Fatalf("KeyReadAll failed: %v", err) + } + if len(keys) != 1 || keys[0].Marshal().KID != "good" { + t.Fatalf("expected only the supported key to remain, got %d keys (first kid=%q)", len(keys), func() string { + if len(keys) > 0 { + return keys[0].Marshal().KID + } + return "" + }()) + } +} + +func TestKeyRead_ErrorsWhenInnerHasUnsupportedKey(t *testing.T) { + ctx := context.Background() + inner := jwkset.NewMemoryStorage() + bad := makeEd25519JWK(t, "bad-inner", "FOO", true) + if err := inner.KeyWrite(ctx, bad); err != nil { + t.Fatalf("failed to write to inner storage: %v", err) + } + + store := NewValidationStore(zap.NewNop(), inner, nil, false) + _, err := store.KeyRead(ctx, "bad-inner") + if err == nil { + t.Fatalf("expected error for unsupported algorithm in inner storage") + } + if !strings.Contains(err.Error(), "unsupported algorithm") { + t.Fatalf("expected error to mention unsupported algorithm, got: %v", err) + } +} diff --git a/router/pkg/config/config.go b/router/pkg/config/config.go index 57591de31e..fe3fd68c52 100644 --- a/router/pkg/config/config.go +++ b/router/pkg/config/config.go @@ -478,7 +478,8 @@ type JWKSConfiguration struct { KeyId string `yaml:"header_key_id"` // Common - Audiences []string `yaml:"audiences"` + Audiences []string `yaml:"audiences"` + AllowEmptyAlgorithm bool `yaml:"allow_empty_algorithm" envDefault:"false"` } type RefreshUnknownKID struct { From c4aa09e2d429758cb2e743b1a870d75d19e3f757 Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Wed, 17 Sep 2025 17:05:42 +0530 Subject: [PATCH 13/45] fix: current --- router-tests/authentication_test.go | 26 ++++++++++++++++++++++++++ router-tests/jwks/crypto.go | 6 +++--- router-tests/jwks/jwks.go | 14 +++++++++++++- router-tests/utils.go | 13 +++++++++++-- 4 files changed, 53 insertions(+), 6 deletions(-) diff --git a/router-tests/authentication_test.go b/router-tests/authentication_test.go index 326b84749a..ad21881a58 100644 --- a/router-tests/authentication_test.go +++ b/router-tests/authentication_test.go @@ -2831,6 +2831,32 @@ func TestAudienceValidation(t *testing.T) { require.Equal(t, employeesExpectedData, string(data)) }) }) + + t.Run("valid token with empty algorithm in JWKS", func(t *testing.T) { + t.Parallel() + + authenticators, authServer := ConfigureAuthWithOpts(t, ConfigureAuthOpts{AllowEmptyAlgorithm: true}) + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, false)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Operations with a token should succeed + token, err := authServer.Token(nil) + require.NoError(t, err) + header := http.Header{ + "Authorization": []string{"Bearer " + token}, + } + res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + require.Equal(t, JwksName, res.Header.Get(xAuthenticatedByHeader)) + data, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, employeesExpectedData, string(data)) + }) + }) } func toJWKSConfig(url string, refresh time.Duration, allowedAlgorithms ...string) authentication.JWKSConfig { diff --git a/router-tests/jwks/crypto.go b/router-tests/jwks/crypto.go index 0b157a8440..61fa0d5d04 100644 --- a/router-tests/jwks/crypto.go +++ b/router-tests/jwks/crypto.go @@ -43,7 +43,7 @@ func (b *baseCrypto) MarshalJWK() (jwkset.JWK, error) { } meta := jwkset.JWKMetadataOptions{ - ALG: b.alg, + //ALG: b.alg, KID: b.kID, USE: jwkset.UseSig, } @@ -76,8 +76,8 @@ func NewRSACrypto(kID string, alg jwkset.ALG, size int) (Crypto, error) { return &rsaCrypto{ baseCrypto: baseCrypto{ - pk: pk, - alg: alg, + pk: pk, + //alg: alg, kID: kID, }, }, nil diff --git a/router-tests/jwks/jwks.go b/router-tests/jwks/jwks.go index 6b77a76812..616bee91fb 100644 --- a/router-tests/jwks/jwks.go +++ b/router-tests/jwks/jwks.go @@ -42,7 +42,8 @@ func (s *Server) Token(claims map[string]any) (string, error) { } for kid, pr := range s.providers { - token := jwt.NewWithClaims(pr.SigningMethod(), jwt.MapClaims(claims)) + method := jwt.GetSigningMethod(jwt.SigningMethodRS256.Alg()) + token := NewWithClaims(method, jwt.MapClaims(claims)) token.Header[jwkset.HeaderKID] = kid return token.SignedString(pr.PrivateKey()) } @@ -50,6 +51,17 @@ func (s *Server) Token(claims map[string]any) (string, error) { return "", jwt.ErrInvalidKey } +func NewWithClaims(method jwt.SigningMethod, claims jwt.Claims, opts ...jwt.TokenOption) *jwt.Token { + return &jwt.Token{ + Header: map[string]interface{}{ + "typ": "JWT", + "alg": jwt.SigningMethodRS256.Alg(), + }, + Claims: claims, + Method: method, + } +} + func (s *Server) TokenForKID(kid string, claims map[string]any, useInvalidKID bool) (string, error) { provider, ok := s.providers[kid] if useInvalidKID { diff --git a/router-tests/utils.go b/router-tests/utils.go index 09bf9b1fcc..53386a5a46 100644 --- a/router-tests/utils.go +++ b/router-tests/utils.go @@ -42,14 +42,23 @@ func RequireSpanWithName(t *testing.T, exporter *tracetest2.InMemoryExporter, na return testSpan } +type ConfigureAuthOpts struct { + AllowEmptyAlgorithm bool +} + func ConfigureAuth(t *testing.T) ([]authentication.Authenticator, *jwks.Server) { + return ConfigureAuthWithOpts(t, ConfigureAuthOpts{}) +} + +func ConfigureAuthWithOpts(t *testing.T, opts ConfigureAuthOpts) ([]authentication.Authenticator, *jwks.Server) { authServer, err := jwks.NewServer(t) require.NoError(t, err) t.Cleanup(authServer.Close) authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ { - URL: authServer.JWKSURL(), - RefreshInterval: time.Second * 5, + URL: authServer.JWKSURL(), + RefreshInterval: time.Second * 5, + AllowEmptyAlgorithm: opts.AllowEmptyAlgorithm, }, }) From 62e7023430f903ac2fb0f1202a784a68cbae3d2f Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Wed, 17 Sep 2025 17:36:53 +0530 Subject: [PATCH 14/45] fix: updates --- router-tests/authentication_test.go | 21 ++- router-tests/jwks/crypto.go | 9 +- router-tests/jwks/jwks.go | 24 ++- router-tests/utils.go | 9 +- .../authentication/validation_store_test.go | 178 ------------------ 5 files changed, 49 insertions(+), 192 deletions(-) delete mode 100644 router/pkg/authentication/validation_store_test.go diff --git a/router-tests/authentication_test.go b/router-tests/authentication_test.go index ad21881a58..b9279d69d1 100644 --- a/router-tests/authentication_test.go +++ b/router-tests/authentication_test.go @@ -2835,14 +2835,31 @@ func TestAudienceValidation(t *testing.T) { t.Run("valid token with empty algorithm in JWKS", func(t *testing.T) { t.Parallel() - authenticators, authServer := ConfigureAuthWithOpts(t, ConfigureAuthOpts{AllowEmptyAlgorithm: true}) + rsaCrypto, err := jwks.NewRSACrypto("", "", 2048) + if err != nil { + t.Fatalf("Failed to create an RSA crypto provider.\nError: %s", err) + } + authServer, err := jwks.NewServerWithCrypto(t, rsaCrypto) + require.NoError(t, err) + t.Cleanup(authServer.Close) + + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + URL: authServer.JWKSURL(), + RefreshInterval: time.Second * 5, + AllowEmptyAlgorithm: true, + }, + }) + testenv.Run(t, &testenv.Config{ RouterOptions: []core.Option{ core.WithAccessController(core.NewAccessController(authenticators, false)), }, }, func(t *testing.T, xEnv *testenv.Environment) { // Operations with a token should succeed - token, err := authServer.Token(nil) + token, err := authServer.TokenWithOpts(nil, jwks.TokenOpts{ + AlgOverride: string(jwkset.AlgRS256), + }) require.NoError(t, err) header := http.Header{ "Authorization": []string{"Bearer " + token}, diff --git a/router-tests/jwks/crypto.go b/router-tests/jwks/crypto.go index 61fa0d5d04..403bd5af39 100644 --- a/router-tests/jwks/crypto.go +++ b/router-tests/jwks/crypto.go @@ -43,11 +43,14 @@ func (b *baseCrypto) MarshalJWK() (jwkset.JWK, error) { } meta := jwkset.JWKMetadataOptions{ - //ALG: b.alg, KID: b.kID, USE: jwkset.UseSig, } + if b.alg != "" { + meta.ALG = b.alg + } + options := jwkset.JWKOptions{ Marshal: marshalOptions, Metadata: meta, @@ -76,8 +79,8 @@ func NewRSACrypto(kID string, alg jwkset.ALG, size int) (Crypto, error) { return &rsaCrypto{ baseCrypto: baseCrypto{ - pk: pk, - //alg: alg, + pk: pk, + alg: alg, kID: kID, }, }, nil diff --git a/router-tests/jwks/jwks.go b/router-tests/jwks/jwks.go index 616bee91fb..14e311a6c2 100644 --- a/router-tests/jwks/jwks.go +++ b/router-tests/jwks/jwks.go @@ -36,14 +36,34 @@ func (s *Server) Close() { s.httpServer.Close() } +type TokenOpts struct { + AlgOverride string +} + func (s *Server) Token(claims map[string]any) (string, error) { + return s.TokenWithOpts(claims, TokenOpts{AlgOverride: ""}) +} + +func (s *Server) TokenWithOpts(claims map[string]any, tokenOpts TokenOpts) (string, error) { if len(s.providers) == 0 { return "", jwt.ErrInvalidKey } for kid, pr := range s.providers { - method := jwt.GetSigningMethod(jwt.SigningMethodRS256.Alg()) - token := NewWithClaims(method, jwt.MapClaims(claims)) + var token *jwt.Token + if tokenOpts.AlgOverride != "" { + token = &jwt.Token{ + Header: map[string]interface{}{ + "typ": "JWT", + "alg": tokenOpts.AlgOverride, + }, + Claims: jwt.MapClaims(claims), + Method: jwt.GetSigningMethod(tokenOpts.AlgOverride), + } + } else { + token = jwt.NewWithClaims(pr.SigningMethod(), jwt.MapClaims(claims)) + } + token.Header[jwkset.HeaderKID] = kid return token.SignedString(pr.PrivateKey()) } diff --git a/router-tests/utils.go b/router-tests/utils.go index 53386a5a46..825bcdb19e 100644 --- a/router-tests/utils.go +++ b/router-tests/utils.go @@ -47,18 +47,13 @@ type ConfigureAuthOpts struct { } func ConfigureAuth(t *testing.T) ([]authentication.Authenticator, *jwks.Server) { - return ConfigureAuthWithOpts(t, ConfigureAuthOpts{}) -} - -func ConfigureAuthWithOpts(t *testing.T, opts ConfigureAuthOpts) ([]authentication.Authenticator, *jwks.Server) { authServer, err := jwks.NewServer(t) require.NoError(t, err) t.Cleanup(authServer.Close) authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ { - URL: authServer.JWKSURL(), - RefreshInterval: time.Second * 5, - AllowEmptyAlgorithm: opts.AllowEmptyAlgorithm, + URL: authServer.JWKSURL(), + RefreshInterval: time.Second * 5, }, }) diff --git a/router/pkg/authentication/validation_store_test.go b/router/pkg/authentication/validation_store_test.go deleted file mode 100644 index a3e21f5a06..0000000000 --- a/router/pkg/authentication/validation_store_test.go +++ /dev/null @@ -1,178 +0,0 @@ -package authentication - -import ( - "context" - "crypto/ecdsa" - "crypto/ed25519" - "crypto/elliptic" - "crypto/rand" - "strings" - "testing" - - "github.com/MicahParks/jwkset" - "go.uber.org/zap" -) - -func makeEd25519JWK(t *testing.T, kid string, alg string, setAlg bool) jwkset.JWK { - t.Helper() - _, priv, err := ed25519.GenerateKey(rand.Reader) - if err != nil { - t.Fatalf("failed to generate ed25519 key: %v", err) - } - - meta := jwkset.JWKMetadataOptions{KID: kid, USE: jwkset.UseSig} - if setAlg { - meta.ALG = jwkset.ALG(alg) - } - options := jwkset.JWKOptions{Metadata: meta} - - j, err := jwkset.NewJWKFromKey(priv, options) - if err != nil { - t.Fatalf("failed to create JWK: %v", err) - } - return j -} - -func makeES256JWK(t *testing.T, kid string) jwkset.JWK { - t.Helper() - priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - t.Fatalf("failed to generate ecdsa p256 key: %v", err) - } - meta := jwkset.JWKMetadataOptions{KID: kid, USE: jwkset.UseSig, ALG: jwkset.AlgES256} - options := jwkset.JWKOptions{Metadata: meta} - j, err := jwkset.NewJWKFromKey(priv, options) - if err != nil { - t.Fatalf("failed to create JWK: %v", err) - } - return j -} - -func TestKeyWriteAndRead_SupportedAlg(t *testing.T) { - ctx := context.Background() - store := NewValidationStore(zap.NewNop(), nil, nil, false) - - good := makeEd25519JWK(t, "good-ed", "EdDSA", true) - if err := store.KeyWrite(ctx, good); err != nil { - t.Fatalf("KeyWrite failed: %v", err) - } - - got, err := store.KeyRead(ctx, "good-ed") - if err != nil { - t.Fatalf("KeyRead failed: %v", err) - } - if got.Marshal().KID != "good-ed" || got.Marshal().ALG.String() != "EdDSA" { - t.Fatalf("unexpected key metadata: kid=%q alg=%q", got.Marshal().KID, got.Marshal().ALG.String()) - } -} - -func TestKeyWrite_SkipsUnsupportedAlg(t *testing.T) { - ctx := context.Background() - store := NewValidationStore(zap.NewNop(), nil, nil, false) - - bad := makeEd25519JWK(t, "bad", "FOO", true) - if err := store.KeyWrite(ctx, bad); err != nil { - t.Fatalf("KeyWrite returned error for unsupported alg (should skip without error): %v", err) - } - - if _, err := store.KeyRead(ctx, "bad"); err == nil { - t.Fatalf("expected KeyRead to fail for skipped unsupported key") - } - - all, err := store.KeyReadAll(ctx) - if err != nil { - t.Fatalf("KeyReadAll failed: %v", err) - } - if len(all) != 0 { - t.Fatalf("expected 0 keys, got %d", len(all)) - } -} - -func TestAllowEmptyAlgorithm(t *testing.T) { - ctx := context.Background() - - // allowEmptyAlgorithm = true accepts keys without ALG - storeAllow := NewValidationStore(zap.NewNop(), nil, nil, true) - noAlg := makeEd25519JWK(t, "noalg", "", false) - if err := storeAllow.KeyWrite(ctx, noAlg); err != nil { - t.Fatalf("KeyWrite failed: %v", err) - } - if _, err := storeAllow.KeyRead(ctx, "noalg"); err != nil { - t.Fatalf("expected KeyRead to succeed for empty ALG when allowed, got: %v", err) - } - - // allowEmptyAlgorithm = false skips keys without ALG - storeDeny := NewValidationStore(zap.NewNop(), nil, nil, false) - if err := storeDeny.KeyWrite(ctx, noAlg); err != nil { - t.Fatalf("KeyWrite returned error for empty ALG (should skip without error): %v", err) - } - if _, err := storeDeny.KeyRead(ctx, "noalg"); err == nil { - t.Fatalf("expected KeyRead to fail for skipped empty-ALG key") - } -} - -func TestRestrictAlgorithmsList(t *testing.T) { - ctx := context.Background() - store := NewValidationStore(zap.NewNop(), nil, []string{"ES256"}, false) - - // EdDSA should be rejected when only ES256 is allowed - ed := makeEd25519JWK(t, "ed", "EdDSA", true) - if err := store.KeyWrite(ctx, ed); err != nil { - t.Fatalf("KeyWrite returned error while skipping disallowed alg: %v", err) - } - if _, err := store.KeyRead(ctx, "ed"); err == nil { - t.Fatalf("expected KeyRead to fail for disallowed alg EdDSA") - } - - // ES256 should be accepted - es := makeES256JWK(t, "es256") - if err := store.KeyWrite(ctx, es); err != nil { - t.Fatalf("KeyWrite failed for ES256: %v", err) - } - if _, err := store.KeyRead(ctx, "es256"); err != nil { - t.Fatalf("expected KeyRead to succeed for ES256, got: %v", err) - } -} - -func TestKeyReplaceAll_Filters(t *testing.T) { - ctx := context.Background() - store := NewValidationStore(zap.NewNop(), nil, nil, false) - - good := makeEd25519JWK(t, "good", "EdDSA", true) - bad := makeEd25519JWK(t, "bad", "FOO", true) - noAlg := makeEd25519JWK(t, "noalg2", "", false) - - if err := store.KeyReplaceAll(ctx, []jwkset.JWK{good, bad, noAlg}); err != nil { - t.Fatalf("KeyReplaceAll failed: %v", err) - } - keys, err := store.KeyReadAll(ctx) - if err != nil { - t.Fatalf("KeyReadAll failed: %v", err) - } - if len(keys) != 1 || keys[0].Marshal().KID != "good" { - t.Fatalf("expected only the supported key to remain, got %d keys (first kid=%q)", len(keys), func() string { - if len(keys) > 0 { - return keys[0].Marshal().KID - } - return "" - }()) - } -} - -func TestKeyRead_ErrorsWhenInnerHasUnsupportedKey(t *testing.T) { - ctx := context.Background() - inner := jwkset.NewMemoryStorage() - bad := makeEd25519JWK(t, "bad-inner", "FOO", true) - if err := inner.KeyWrite(ctx, bad); err != nil { - t.Fatalf("failed to write to inner storage: %v", err) - } - - store := NewValidationStore(zap.NewNop(), inner, nil, false) - _, err := store.KeyRead(ctx, "bad-inner") - if err == nil { - t.Fatalf("expected error for unsupported algorithm in inner storage") - } - if !strings.Contains(err.Error(), "unsupported algorithm") { - t.Fatalf("expected error to mention unsupported algorithm, got: %v", err) - } -} From 7ed09a88abd1f15d2deda1d43a22ffde7e279861 Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Wed, 17 Sep 2025 18:15:47 +0530 Subject: [PATCH 15/45] fix: cleanup --- router-tests/jwks/jwks.go | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/router-tests/jwks/jwks.go b/router-tests/jwks/jwks.go index 14e311a6c2..5eb8fc8646 100644 --- a/router-tests/jwks/jwks.go +++ b/router-tests/jwks/jwks.go @@ -71,17 +71,6 @@ func (s *Server) TokenWithOpts(claims map[string]any, tokenOpts TokenOpts) (stri return "", jwt.ErrInvalidKey } -func NewWithClaims(method jwt.SigningMethod, claims jwt.Claims, opts ...jwt.TokenOption) *jwt.Token { - return &jwt.Token{ - Header: map[string]interface{}{ - "typ": "JWT", - "alg": jwt.SigningMethodRS256.Alg(), - }, - Claims: claims, - Method: method, - } -} - func (s *Server) TokenForKID(kid string, claims map[string]any, useInvalidKID bool) (string, error) { provider, ok := s.providers[kid] if useInvalidKID { From 2ac59991bc09046931894e5c22aa9c0756ac532e Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Wed, 17 Sep 2025 18:23:04 +0530 Subject: [PATCH 16/45] fix: add schema --- router/pkg/config/config.schema.json | 4 ++++ router/pkg/config/fixtures/full.yaml | 1 + router/pkg/config/testdata/config_full.json | 9 ++++++--- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/router/pkg/config/config.schema.json b/router/pkg/config/config.schema.json index 570d5f8086..e3b5ab2456 100644 --- a/router/pkg/config/config.schema.json +++ b/router/pkg/config/config.schema.json @@ -1687,6 +1687,10 @@ "type": "string", "description": "The KID header of the JWK token created using the secret" }, + "allow_empty_algorithm": { + "type": "boolean", + "description": "This attribute can be enabled to allow for the JWK to contain keys with empty algorithms" + }, "algorithms": { "type": "array", "description": "The allowed algorithms for the keys that are retrieved from the JWKs. An empty list means that all algorithms are allowed.", diff --git a/router/pkg/config/fixtures/full.yaml b/router/pkg/config/fixtures/full.yaml index a43691cc12..7121f81639 100644 --- a/router/pkg/config/fixtures/full.yaml +++ b/router/pkg/config/fixtures/full.yaml @@ -272,6 +272,7 @@ authentication: - url: 'https://example.com/.well-known/jwks.json' refresh_interval: 1m algorithms: ['RS256'] + allow_empty_algorithm: true - url: 'https://example.com/.well-known/jwks2.json' refresh_interval: 2m algorithms: ['RS256', 'ES256'] diff --git a/router/pkg/config/testdata/config_full.json b/router/pkg/config/testdata/config_full.json index ba295c3935..ef95998140 100644 --- a/router/pkg/config/testdata/config_full.json +++ b/router/pkg/config/testdata/config_full.json @@ -489,7 +489,8 @@ "Secret": "", "Algorithm": "", "KeyId": "", - "Audiences": null + "Audiences": null, + "AllowEmptyAlgorithm": true }, { "URL": "https://example.com/.well-known/jwks2.json", @@ -507,7 +508,8 @@ "Secret": "", "Algorithm": "", "KeyId": "", - "Audiences": null + "Audiences": null, + "AllowEmptyAlgorithm": false }, { "URL": "https://example.com/.well-known/jwks3.json", @@ -522,7 +524,8 @@ "Secret": "", "Algorithm": "", "KeyId": "", - "Audiences": null + "Audiences": null, + "AllowEmptyAlgorithm": false } ], "HeaderName": "Authorization", From 47816e23d9a32a800f3400bd5fdcd81bd638dc1c Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Wed, 17 Sep 2025 21:25:31 +0530 Subject: [PATCH 17/45] fix: refactoring --- router-tests/authentication_test.go | 47 +++++++++++++++++-- .../pkg/authentication/jwks_token_decoder.go | 47 ++++++++++++++++--- router/pkg/authentication/validation_store.go | 14 ++++-- router/pkg/config/config.schema.json | 44 ++++++++++++----- 4 files changed, 127 insertions(+), 25 deletions(-) diff --git a/router-tests/authentication_test.go b/router-tests/authentication_test.go index b9279d69d1..f38bf054be 100644 --- a/router-tests/authentication_test.go +++ b/router-tests/authentication_test.go @@ -2836,9 +2836,8 @@ func TestAudienceValidation(t *testing.T) { t.Parallel() rsaCrypto, err := jwks.NewRSACrypto("", "", 2048) - if err != nil { - t.Fatalf("Failed to create an RSA crypto provider.\nError: %s", err) - } + require.NoError(t, err) + authServer, err := jwks.NewServerWithCrypto(t, rsaCrypto) require.NoError(t, err) t.Cleanup(authServer.Close) @@ -2874,6 +2873,48 @@ func TestAudienceValidation(t *testing.T) { require.Equal(t, employeesExpectedData, string(data)) }) }) + + t.Run("verify blocking invalid specified algorithm even though token is valid", func(t *testing.T) { + t.Parallel() + + rsaCrypto, err := jwks.NewRSACrypto("", "", 2048) + require.NoError(t, err) + + authServer, err := jwks.NewServerWithCrypto(t, rsaCrypto) + require.NoError(t, err) + t.Cleanup(authServer.Close) + + allowedAlgorithm := jwkset.AlgRS256 + + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + URL: authServer.JWKSURL(), + RefreshInterval: time.Second * 5, + AllowedAlgorithms: []string{string(allowedAlgorithm)}, + AllowEmptyAlgorithm: true, + }, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, false)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Fail with RS512 + token2, err := authServer.TokenWithOpts(nil, jwks.TokenOpts{ + AlgOverride: string(jwkset.AlgRS512), + }) + require.NoError(t, err) + res2, err := xEnv.MakeRequest(http.MethodPost, "/graphql", http.Header{ + "Authorization": []string{"Bearer " + token2}, + }, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { + _ = res2.Body.Close() + }() + require.Equal(t, http.StatusUnauthorized, res2.StatusCode) + }) + }) } func toJWKSConfig(url string, refresh time.Duration, allowedAlgorithms ...string) authentication.JWKSConfig { diff --git a/router/pkg/authentication/jwks_token_decoder.go b/router/pkg/authentication/jwks_token_decoder.go index fb4cbe6de1..88ecd1e3a5 100644 --- a/router/pkg/authentication/jwks_token_decoder.go +++ b/router/pkg/authentication/jwks_token_decoder.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net/http" + "slices" "time" "golang.org/x/time/rate" @@ -67,9 +68,15 @@ type audKey struct { type audienceSet map[string]struct{} +type keyFuncWithOpts struct { + keyFunc keyfunc.Keyfunc + allowEmptyAlgorithm bool + allowedAlgorithms []string +} + func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKSConfig) (TokenDecoder, error) { audiencesMap := make(map[audKey]audienceSet, len(configs)) - keyFuncMap := make(map[audKey]keyfunc.Keyfunc, len(configs)) + keyFuncMap := make(map[audKey]keyFuncWithOpts, len(configs)) for _, c := range configs { if c.URL != "" { @@ -80,6 +87,8 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS l := logger.With(zap.String("url", c.URL)) + newValidationStore, processedAllowedAlgorithms := NewValidationStore(logger, nil, c.AllowedAlgorithms, c.AllowEmptyAlgorithm) + jwksetHTTPStorageOptions := jwkset.HTTPClientStorageOptions{ Client: newOIDCDiscoveryClient(httpclient.NewRetryableHTTPClient(l)), Ctx: ctx, // Used to end background refresh goroutine. @@ -90,7 +99,7 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS l.Error("Failed to refresh HTTP JWK Set from remote HTTP resource.", zap.Error(err)) }, RefreshInterval: c.RefreshInterval, - Storage: NewValidationStore(logger, nil, c.AllowedAlgorithms, c.AllowEmptyAlgorithm), + Storage: newValidationStore, } store, err := jwkset.NewStorageFromHTTP(c.URL, jwksetHTTPStorageOptions) @@ -117,14 +126,17 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS if err != nil { return nil, err } - keyFuncMap[key] = jwks + keyFuncMap[key] = keyFuncWithOpts{ + keyFunc: jwks, + allowEmptyAlgorithm: c.AllowEmptyAlgorithm, + allowedAlgorithms: processedAllowedAlgorithms, + } } else if c.Secret != "" { key := audKey{kid: c.KeyId} if _, ok := audiencesMap[key]; ok { return nil, fmt.Errorf("duplicate JWK keyid specified found: %s", c.KeyId) } - given := jwkset.NewMemoryStorage() marshalOptions := jwkset.JWKMarshalOptions{ @@ -168,14 +180,35 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS if err != nil { return nil, err } - keyFuncMap[key] = jwks + keyFuncMap[key] = keyFuncWithOpts{ + keyFunc: jwks, + allowedAlgorithms: []string{c.Algorithm}, + } } } keyFuncWrapper := jwt.Keyfunc(func(token *jwt.Token) (any, error) { var errJoin error - for key, keyFunc := range keyFuncMap { - pub, err := keyFunc.Keyfunc(token) + for key, keyFuncAndOpts := range keyFuncMap { + // TODO: We can enable this for non empty cases, though it is a potential breaking change + // if users are using RS512 after specifying RS256 for example, to discuss + if keyFuncAndOpts.allowEmptyAlgorithm { + // We use the same error messages as keyfunc.Keyfunc + algInter, ok := token.Header["alg"] + if !ok { + return nil, fmt.Errorf("%w: could not find alg in JWT header", keyfunc.ErrKeyfunc) + } + alg, ok := algInter.(string) + if !ok { + return nil, fmt.Errorf(`%w: the JWT header did not contain the "alg" parameter, which is required by RFC 7515 section 4.1.1`, keyfunc.ErrKeyfunc) + } + // This is a custom validation different from keyfunc.Keyfunc + if !slices.Contains(keyFuncAndOpts.allowedAlgorithms, alg) { + return nil, fmt.Errorf("%w: could not find alg %s in allow list", keyfunc.ErrKeyfunc, alg) + } + } + + pub, err := keyFuncAndOpts.keyFunc.Keyfunc(token) if err != nil { errJoin = errors.Join(errJoin, err) continue diff --git a/router/pkg/authentication/validation_store.go b/router/pkg/authentication/validation_store.go index 62698167c8..081a63ef52 100644 --- a/router/pkg/authentication/validation_store.go +++ b/router/pkg/authentication/validation_store.go @@ -34,7 +34,7 @@ var supportedAlgorithms = map[string]struct{}{ "EdDSA": {}, } -func NewValidationStore(logger *zap.Logger, inner jwkset.Storage, algs []string, allowEmptyAlgorithm bool) jwkset.Storage { +func NewValidationStore(logger *zap.Logger, inner jwkset.Storage, algs []string, allowEmptyAlgorithm bool) (jwkset.Storage, []string) { if inner == nil { inner = jwkset.NewMemoryStorage() } @@ -53,7 +53,7 @@ func NewValidationStore(logger *zap.Logger, inner jwkset.Storage, algs []string, } if len(algs) == 0 { - return store + return store, nil } for _, alg := range algs { @@ -65,7 +65,15 @@ func NewValidationStore(logger *zap.Logger, inner jwkset.Storage, algs []string, } store.algs = algSet - return store + return store, store.getSupportedAlgorithms() +} + +func (v *validationStore) getSupportedAlgorithms() []string { + algs := make([]string, 0, len(v.algs)) + for alg := range v.algs { + algs = append(algs, alg) + } + return algs } func (v *validationStore) KeyDelete(ctx context.Context, keyID string) (ok bool, err error) { diff --git a/router/pkg/config/config.schema.json b/router/pkg/config/config.schema.json index e3b5ab2456..6a392d987b 100644 --- a/router/pkg/config/config.schema.json +++ b/router/pkg/config/config.schema.json @@ -1758,20 +1758,37 @@ }, "oneOf": [ { - "required": ["url"], - "not": { - "anyOf": [ - { - "required": ["secret"] - }, - { - "required": ["symmetric_algorithm"] + "allOf": [ + { + "required": ["url"], + "not": { + "anyOf": [ + { + "required": ["secret"] + }, + { + "required": ["symmetric_algorithm"] + }, + { + "required": ["header_key_id"] + } + ] + } + }, + { + "if": { + "required": ["allow_empty_algorithm"] }, - { - "required": ["header_key_id"] + "then": { + "required": ["algorithms"], + "properties": { + "algorithms": { + "minItems": 1 + } + } } - ] - } + } + ] }, { "required": ["secret", "symmetric_algorithm", "header_key_id"], @@ -1785,6 +1802,9 @@ }, { "required": ["refresh_interval"] + }, + { + "required": ["allow_empty_algorithm"] } ] } From 3974969b5173144506c7ff4c3a137e4c58733779 Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Wed, 17 Sep 2025 21:35:38 +0530 Subject: [PATCH 18/45] fix: refactor comment --- router/pkg/authentication/jwks_token_decoder.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/router/pkg/authentication/jwks_token_decoder.go b/router/pkg/authentication/jwks_token_decoder.go index 88ecd1e3a5..416d42ee4f 100644 --- a/router/pkg/authentication/jwks_token_decoder.go +++ b/router/pkg/authentication/jwks_token_decoder.go @@ -190,10 +190,11 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS keyFuncWrapper := jwt.Keyfunc(func(token *jwt.Token) (any, error) { var errJoin error for key, keyFuncAndOpts := range keyFuncMap { - // TODO: We can enable this for non empty cases, though it is a potential breaking change - // if users are using RS512 after specifying RS256 for example, to discuss + // When an algorithm is actually provided in the jwks the current keyfunc will validate the + // jwts algorithm with it. But when no algorithm is provided (alg: none or missing alg) + // the default keyfunc will not validate the algorithm as it has nothing to cross check. if keyFuncAndOpts.allowEmptyAlgorithm { - // We use the same error messages as keyfunc.Keyfunc + // We use the same error messages as keyfunc.Keyfunc for consistency algInter, ok := token.Header["alg"] if !ok { return nil, fmt.Errorf("%w: could not find alg in JWT header", keyfunc.ErrKeyfunc) @@ -202,6 +203,7 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS if !ok { return nil, fmt.Errorf(`%w: the JWT header did not contain the "alg" parameter, which is required by RFC 7515 section 4.1.1`, keyfunc.ErrKeyfunc) } + // This is a custom validation different from keyfunc.Keyfunc if !slices.Contains(keyFuncAndOpts.allowedAlgorithms, alg) { return nil, fmt.Errorf("%w: could not find alg %s in allow list", keyfunc.ErrKeyfunc, alg) From 3b93c0572d39dd5d0c9977d70bfeadda38918ec4 Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Wed, 17 Sep 2025 21:49:40 +0530 Subject: [PATCH 19/45] fix: bug resolving --- router/pkg/authentication/validation_store.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/router/pkg/authentication/validation_store.go b/router/pkg/authentication/validation_store.go index 081a63ef52..96f174b246 100644 --- a/router/pkg/authentication/validation_store.go +++ b/router/pkg/authentication/validation_store.go @@ -53,7 +53,7 @@ func NewValidationStore(logger *zap.Logger, inner jwkset.Storage, algs []string, } if len(algs) == 0 { - return store, nil + return store, store.getSupportedAlgorithms() } for _, alg := range algs { From fb6a32edf95a7c832cdd9aeb4f04b5eef009db3f Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Wed, 17 Sep 2025 22:06:53 +0530 Subject: [PATCH 20/45] fix: review comments --- router-tests/jwks/jwks.go | 11 ++++------- router/pkg/authentication/jwks_token_decoder.go | 9 ++++++--- router/pkg/config/config.schema.json | 5 ++++- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/router-tests/jwks/jwks.go b/router-tests/jwks/jwks.go index 5eb8fc8646..5b329ae475 100644 --- a/router-tests/jwks/jwks.go +++ b/router-tests/jwks/jwks.go @@ -52,14 +52,11 @@ func (s *Server) TokenWithOpts(claims map[string]any, tokenOpts TokenOpts) (stri for kid, pr := range s.providers { var token *jwt.Token if tokenOpts.AlgOverride != "" { - token = &jwt.Token{ - Header: map[string]interface{}{ - "typ": "JWT", - "alg": tokenOpts.AlgOverride, - }, - Claims: jwt.MapClaims(claims), - Method: jwt.GetSigningMethod(tokenOpts.AlgOverride), + method := jwt.GetSigningMethod(tokenOpts.AlgOverride) + if method == nil { + return "", fmt.Errorf("unsupported signing method: %s", tokenOpts.AlgOverride) } + token = jwt.NewWithClaims(method, jwt.MapClaims(claims)) } else { token = jwt.NewWithClaims(pr.SigningMethod(), jwt.MapClaims(claims)) } diff --git a/router/pkg/authentication/jwks_token_decoder.go b/router/pkg/authentication/jwks_token_decoder.go index 416d42ee4f..09e5d6b2ad 100644 --- a/router/pkg/authentication/jwks_token_decoder.go +++ b/router/pkg/authentication/jwks_token_decoder.go @@ -197,16 +197,19 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS // We use the same error messages as keyfunc.Keyfunc for consistency algInter, ok := token.Header["alg"] if !ok { - return nil, fmt.Errorf("%w: could not find alg in JWT header", keyfunc.ErrKeyfunc) + errJoin = errors.Join(errJoin, fmt.Errorf("%w: could not find alg in JWT header", keyfunc.ErrKeyfunc)) + continue } alg, ok := algInter.(string) if !ok { - return nil, fmt.Errorf(`%w: the JWT header did not contain the "alg" parameter, which is required by RFC 7515 section 4.1.1`, keyfunc.ErrKeyfunc) + errJoin = errors.Join(errJoin, fmt.Errorf(`%w: the JWT header did not contain the "alg" parameter, which is required by RFC 7515 section 4.1.1`, keyfunc.ErrKeyfunc)) + continue } // This is a custom validation different from keyfunc.Keyfunc if !slices.Contains(keyFuncAndOpts.allowedAlgorithms, alg) { - return nil, fmt.Errorf("%w: could not find alg %s in allow list", keyfunc.ErrKeyfunc, alg) + errJoin = errors.Join(errJoin, fmt.Errorf("%w: could not find alg %s in allow list", keyfunc.ErrKeyfunc, alg)) + continue } } diff --git a/router/pkg/config/config.schema.json b/router/pkg/config/config.schema.json index 6a392d987b..87e924d0ee 100644 --- a/router/pkg/config/config.schema.json +++ b/router/pkg/config/config.schema.json @@ -1777,7 +1777,10 @@ }, { "if": { - "required": ["allow_empty_algorithm"] + "required": ["allow_empty_algorithm"], + "properties": { + "allow_empty_algorithm": { "const": true } + } }, "then": { "required": ["algorithms"], From cea65d489b71109e9d079ba4c10a664b7360d877 Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Wed, 17 Sep 2025 22:36:40 +0530 Subject: [PATCH 21/45] fix: audience --- router/pkg/authentication/jwks_token_decoder.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/router/pkg/authentication/jwks_token_decoder.go b/router/pkg/authentication/jwks_token_decoder.go index 09e5d6b2ad..b358978363 100644 --- a/router/pkg/authentication/jwks_token_decoder.go +++ b/router/pkg/authentication/jwks_token_decoder.go @@ -223,10 +223,12 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS if len(expectedAudiences) > 0 { tokenAudiences, err := token.Claims.GetAudience() if err != nil { - return nil, fmt.Errorf("could not get audiences from token claims: %w", err) + errJoin = errors.Join(errJoin, fmt.Errorf("could not get audiences from token claims: %w", err)) + continue } if !hasAudience(tokenAudiences, expectedAudiences) { - return nil, errUnacceptableAud + errJoin = errors.Join(errJoin, errUnacceptableAud) + continue } } return pub, nil From c28dd2b16a97b06039259913808d9b8154d60e40 Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Thu, 18 Sep 2025 15:35:35 +0530 Subject: [PATCH 22/45] fix: initial validation store unit test --- .../authentication/validation_store_test.go | 253 ++++++++++++++++++ 1 file changed, 253 insertions(+) create mode 100644 router/pkg/authentication/validation_store_test.go diff --git a/router/pkg/authentication/validation_store_test.go b/router/pkg/authentication/validation_store_test.go new file mode 100644 index 0000000000..6da5a95f87 --- /dev/null +++ b/router/pkg/authentication/validation_store_test.go @@ -0,0 +1,253 @@ +package authentication + +import ( + "context" + "crypto/ed25519" + "crypto/hmac" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "testing" + + "github.com/MicahParks/jwkset" + requires "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +func TestValidationStore(t *testing.T) { + // verify KeyWrite + t.Run("verify KeyWrite", func(t *testing.T) { + ctx := context.Background() + + t.Run("accepts supported algorithms without filter", func(t *testing.T) { + inner := jwkset.NewMemoryStorage() + store := NewValidationStore(nil, inner, nil) + keys := []jwkset.JWK{ + genRSAJWK(t, "rsa1", jwkset.AlgRS256), + genHMACJWK(t, "hmac1", jwkset.AlgHS256), + genEd25519JWK(t, "eddsa1"), + } + for _, k := range keys { + requires.NoError(t, store.KeyWrite(ctx, k)) + } + allInner, err := inner.KeyReadAll(ctx) + requires.NoError(t, err) + requires.Len(t, allInner, len(keys)) + }) + + t.Run("skips disallowed algorithms when filtered", func(t *testing.T) { + inner := jwkset.NewMemoryStorage() + store := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}) + allowed := genRSAJWK(t, "rsa-allowed", jwkset.AlgRS256) + disallowed := genHMACJWK(t, "hmac-blocked", jwkset.AlgHS256) + requires.NoError(t, store.KeyWrite(ctx, allowed)) + requires.NoError(t, store.KeyWrite(ctx, disallowed)) // skipped, not error + all, err := inner.KeyReadAll(ctx) + requires.NoError(t, err) + requires.Len(t, all, 1) + requires.Equal(t, allowed.Marshal().KID, all[0].Marshal().KID) + }) + }) + + t.Run("verify KeyRead", func(t *testing.T) { + ctx := context.Background() + inner := jwkset.NewMemoryStorage() + allowed := genRSAJWK(t, "rsa-1", jwkset.AlgRS256) + disallowed := genHMACJWK(t, "hmac-1", jwkset.AlgHS256) + requires.NoError(t, inner.KeyWrite(ctx, allowed)) + requires.NoError(t, inner.KeyWrite(ctx, disallowed)) + store := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}) + + _, err := store.KeyRead(ctx, allowed.Marshal().KID) + requires.NoError(t, err) + _, err = store.KeyRead(ctx, disallowed.Marshal().KID) + requires.ErrorContains(t, err, "unsupported algorithm") + }) + + t.Run("verify KeyReadAll", func(t *testing.T) { + ctx := context.Background() + inner := jwkset.NewMemoryStorage() + allowed := genRSAJWK(t, "rsa-1", jwkset.AlgRS256) + disallowed := genHMACJWK(t, "hmac-1", jwkset.AlgHS256) + requires.NoError(t, inner.KeyWrite(ctx, allowed)) + requires.NoError(t, inner.KeyWrite(ctx, disallowed)) + store := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}) + + keys, err := store.KeyReadAll(ctx) + requires.NoError(t, err) + requires.Len(t, keys, 1) + m := keys[0].Marshal() + requires.Equal(t, allowed.Marshal().KID, m.KID) + }) + + t.Run("verify KeyReplaceAll", func(t *testing.T) { + ctx := context.Background() + inner := jwkset.NewMemoryStorage() + store := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}) + + allowed := genRSAJWK(t, "rsa-1", jwkset.AlgRS256) + disallowed := genHMACJWK(t, "hmac-1", jwkset.AlgHS256) + + requires.NoError(t, store.KeyReplaceAll(ctx, []jwkset.JWK{allowed, disallowed})) + + all, err := inner.KeyReadAll(ctx) + requires.NoError(t, err) + requires.Len(t, all, 1) + requires.Equal(t, allowed.Marshal().KID, all[0].Marshal().KID) + }) + + t.Run("verify KeyDelete", func(t *testing.T) { + ctx := context.Background() + inner := jwkset.NewMemoryStorage() + store := NewValidationStore(zap.NewNop(), inner, nil) + + key := genRSAJWK(t, "rsa-del", jwkset.AlgRS256) + requires.NoError(t, store.KeyWrite(ctx, key)) + ok, err := store.KeyDelete(ctx, key.Marshal().KID) + requires.NoError(t, err) + requires.True(t, ok) + _, err = inner.KeyRead(ctx, key.Marshal().KID) + requires.ErrorContains(t, err, "not found") + }) + + t.Run("verify JSON", func(t *testing.T) { + ctx := context.Background() + inner := jwkset.NewMemoryStorage() + allowed := genRSAJWK(t, "rsa-json", jwkset.AlgRS256) + disallowed := genHMACJWK(t, "hmac-json", jwkset.AlgHS256) + requires.NoError(t, inner.KeyWrite(ctx, allowed)) + requires.NoError(t, inner.KeyWrite(ctx, disallowed)) + store := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}) + + a, err := store.JSON(ctx) + requires.NoError(t, err) + b, err := inner.JSON(ctx) + requires.NoError(t, err) + requires.JSONEq(t, string(b), string(a)) + }) + + t.Run("verify JSONPublic", func(t *testing.T) { + ctx := context.Background() + inner := jwkset.NewMemoryStorage() + allowed := genRSAJWK(t, "rsa-json", jwkset.AlgRS256) + disallowed := genHMACJWK(t, "hmac-json", jwkset.AlgHS256) + requires.NoError(t, inner.KeyWrite(ctx, allowed)) + requires.NoError(t, inner.KeyWrite(ctx, disallowed)) + store := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}) + + a, err := store.JSONPublic(ctx) + requires.NoError(t, err) + b, err := inner.JSONPublic(ctx) + requires.NoError(t, err) + requires.JSONEq(t, string(b), string(a)) + }) + + t.Run("verify JSONPrivate", func(t *testing.T) { + ctx := context.Background() + inner := jwkset.NewMemoryStorage() + allowed := genRSAJWK(t, "rsa-json", jwkset.AlgRS256) + disallowed := genHMACJWK(t, "hmac-json", jwkset.AlgHS256) + requires.NoError(t, inner.KeyWrite(ctx, allowed)) + requires.NoError(t, inner.KeyWrite(ctx, disallowed)) + store := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}) + + a, err := store.JSONPrivate(ctx) + requires.NoError(t, err) + b, err := inner.JSONPrivate(ctx) + requires.NoError(t, err) + requires.JSONEq(t, string(b), string(a)) + }) + + t.Run("verify JSONWithOptions", func(t *testing.T) { + ctx := context.Background() + inner := jwkset.NewMemoryStorage() + allowed := genRSAJWK(t, "rsa-json", jwkset.AlgRS256) + disallowed := genHMACJWK(t, "hmac-json", jwkset.AlgHS256) + requires.NoError(t, inner.KeyWrite(ctx, allowed)) + requires.NoError(t, inner.KeyWrite(ctx, disallowed)) + store := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}) + + a, err := store.JSONWithOptions(ctx, jwkset.JWKMarshalOptions{}, jwkset.JWKValidateOptions{}) + requires.NoError(t, err) + b, err := inner.JSONWithOptions(ctx, jwkset.JWKMarshalOptions{}, jwkset.JWKValidateOptions{}) + requires.NoError(t, err) + requires.JSONEq(t, string(b), string(a)) + }) + + t.Run("verify Marshal", func(t *testing.T) { + ctx := context.Background() + inner := jwkset.NewMemoryStorage() + allowed := genRSAJWK(t, "rsa-json", jwkset.AlgRS256) + disallowed := genHMACJWK(t, "hmac-json", jwkset.AlgHS256) + requires.NoError(t, inner.KeyWrite(ctx, allowed)) + requires.NoError(t, inner.KeyWrite(ctx, disallowed)) + store := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}) + + ma, err := store.Marshal(ctx) + requires.NoError(t, err) + mb, err := inner.Marshal(ctx) + requires.NoError(t, err) + requires.Equal(t, len(mb.Keys), len(ma.Keys)) + }) + + t.Run("verify MarshalWithOptions", func(t *testing.T) { + ctx := context.Background() + inner := jwkset.NewMemoryStorage() + allowed := genRSAJWK(t, "rsa-json", jwkset.AlgRS256) + disallowed := genHMACJWK(t, "hmac-json", jwkset.AlgHS256) + requires.NoError(t, inner.KeyWrite(ctx, allowed)) + requires.NoError(t, inner.KeyWrite(ctx, disallowed)) + store := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}) + + ma, err := store.MarshalWithOptions(ctx, jwkset.JWKMarshalOptions{}, jwkset.JWKValidateOptions{}) + requires.NoError(t, err) + mb, err := inner.MarshalWithOptions(ctx, jwkset.JWKMarshalOptions{}, jwkset.JWKValidateOptions{}) + requires.NoError(t, err) + requires.Equal(t, len(mb.Keys), len(ma.Keys)) + }) +} + +func genRSAJWK(t *testing.T, kid string, alg jwkset.ALG) jwkset.JWK { + t.Helper() + pk, err := rsa.GenerateKey(rand.Reader, 2048) + requires.NoError(t, err) + opts := jwkset.JWKOptions{ + Marshal: jwkset.JWKMarshalOptions{Private: false}, + Metadata: jwkset.JWKMetadataOptions{ALG: alg, KID: kid, USE: jwkset.UseSig}, + } + j, err := jwkset.NewJWKFromKey(pk, opts) + requires.NoError(t, err) + return j +} + +func genHMACJWK(t *testing.T, kid string, alg jwkset.ALG) jwkset.JWK { + t.Helper() + secret := make([]byte, 64) + _, err := rand.Read(secret) + requires.NoError(t, err) + // Use HMAC to derive a stable-length key material; any []byte works for JWK creation. + h := hmac.New(sha256.New, secret) + _, err = h.Write([]byte("test")) + requires.NoError(t, err) + key := h.Sum(nil) + opts := jwkset.JWKOptions{ + Marshal: jwkset.JWKMarshalOptions{Private: true}, + Metadata: jwkset.JWKMetadataOptions{ALG: alg, KID: kid, USE: jwkset.UseSig}, + } + j, err := jwkset.NewJWKFromKey(key, opts) + requires.NoError(t, err) + return j +} + +func genEd25519JWK(t *testing.T, kid string) jwkset.JWK { + t.Helper() + _, priv, err := ed25519.GenerateKey(rand.Reader) + requires.NoError(t, err) + opts := jwkset.JWKOptions{ + Marshal: jwkset.JWKMarshalOptions{Private: false}, + Metadata: jwkset.JWKMetadataOptions{ALG: jwkset.AlgEdDSA, KID: kid, USE: jwkset.UseSig}, + } + j, err := jwkset.NewJWKFromKey(priv, opts) + requires.NoError(t, err) + return j +} From c2e55b64ba2645ec10b56d65142952723c109150 Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Thu, 18 Sep 2025 16:01:50 +0530 Subject: [PATCH 23/45] fix: compilation --- .../authentication/validation_store_test.go | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/router/pkg/authentication/validation_store_test.go b/router/pkg/authentication/validation_store_test.go index 6da5a95f87..7e13942c1c 100644 --- a/router/pkg/authentication/validation_store_test.go +++ b/router/pkg/authentication/validation_store_test.go @@ -21,7 +21,7 @@ func TestValidationStore(t *testing.T) { t.Run("accepts supported algorithms without filter", func(t *testing.T) { inner := jwkset.NewMemoryStorage() - store := NewValidationStore(nil, inner, nil) + store, _ := NewValidationStore(nil, inner, nil, false) keys := []jwkset.JWK{ genRSAJWK(t, "rsa1", jwkset.AlgRS256), genHMACJWK(t, "hmac1", jwkset.AlgHS256), @@ -37,7 +37,7 @@ func TestValidationStore(t *testing.T) { t.Run("skips disallowed algorithms when filtered", func(t *testing.T) { inner := jwkset.NewMemoryStorage() - store := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}) + store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, false) allowed := genRSAJWK(t, "rsa-allowed", jwkset.AlgRS256) disallowed := genHMACJWK(t, "hmac-blocked", jwkset.AlgHS256) requires.NoError(t, store.KeyWrite(ctx, allowed)) @@ -56,7 +56,7 @@ func TestValidationStore(t *testing.T) { disallowed := genHMACJWK(t, "hmac-1", jwkset.AlgHS256) requires.NoError(t, inner.KeyWrite(ctx, allowed)) requires.NoError(t, inner.KeyWrite(ctx, disallowed)) - store := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}) + store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, false) _, err := store.KeyRead(ctx, allowed.Marshal().KID) requires.NoError(t, err) @@ -71,7 +71,7 @@ func TestValidationStore(t *testing.T) { disallowed := genHMACJWK(t, "hmac-1", jwkset.AlgHS256) requires.NoError(t, inner.KeyWrite(ctx, allowed)) requires.NoError(t, inner.KeyWrite(ctx, disallowed)) - store := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}) + store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, false) keys, err := store.KeyReadAll(ctx) requires.NoError(t, err) @@ -83,7 +83,7 @@ func TestValidationStore(t *testing.T) { t.Run("verify KeyReplaceAll", func(t *testing.T) { ctx := context.Background() inner := jwkset.NewMemoryStorage() - store := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}) + store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, false) allowed := genRSAJWK(t, "rsa-1", jwkset.AlgRS256) disallowed := genHMACJWK(t, "hmac-1", jwkset.AlgHS256) @@ -99,7 +99,7 @@ func TestValidationStore(t *testing.T) { t.Run("verify KeyDelete", func(t *testing.T) { ctx := context.Background() inner := jwkset.NewMemoryStorage() - store := NewValidationStore(zap.NewNop(), inner, nil) + store, _ := NewValidationStore(zap.NewNop(), inner, nil, false) key := genRSAJWK(t, "rsa-del", jwkset.AlgRS256) requires.NoError(t, store.KeyWrite(ctx, key)) @@ -117,7 +117,7 @@ func TestValidationStore(t *testing.T) { disallowed := genHMACJWK(t, "hmac-json", jwkset.AlgHS256) requires.NoError(t, inner.KeyWrite(ctx, allowed)) requires.NoError(t, inner.KeyWrite(ctx, disallowed)) - store := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}) + store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, false) a, err := store.JSON(ctx) requires.NoError(t, err) @@ -133,7 +133,7 @@ func TestValidationStore(t *testing.T) { disallowed := genHMACJWK(t, "hmac-json", jwkset.AlgHS256) requires.NoError(t, inner.KeyWrite(ctx, allowed)) requires.NoError(t, inner.KeyWrite(ctx, disallowed)) - store := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}) + store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, false) a, err := store.JSONPublic(ctx) requires.NoError(t, err) @@ -149,7 +149,7 @@ func TestValidationStore(t *testing.T) { disallowed := genHMACJWK(t, "hmac-json", jwkset.AlgHS256) requires.NoError(t, inner.KeyWrite(ctx, allowed)) requires.NoError(t, inner.KeyWrite(ctx, disallowed)) - store := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}) + store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, false) a, err := store.JSONPrivate(ctx) requires.NoError(t, err) @@ -165,7 +165,7 @@ func TestValidationStore(t *testing.T) { disallowed := genHMACJWK(t, "hmac-json", jwkset.AlgHS256) requires.NoError(t, inner.KeyWrite(ctx, allowed)) requires.NoError(t, inner.KeyWrite(ctx, disallowed)) - store := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}) + store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, false) a, err := store.JSONWithOptions(ctx, jwkset.JWKMarshalOptions{}, jwkset.JWKValidateOptions{}) requires.NoError(t, err) @@ -181,7 +181,7 @@ func TestValidationStore(t *testing.T) { disallowed := genHMACJWK(t, "hmac-json", jwkset.AlgHS256) requires.NoError(t, inner.KeyWrite(ctx, allowed)) requires.NoError(t, inner.KeyWrite(ctx, disallowed)) - store := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}) + store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, false) ma, err := store.Marshal(ctx) requires.NoError(t, err) @@ -197,7 +197,7 @@ func TestValidationStore(t *testing.T) { disallowed := genHMACJWK(t, "hmac-json", jwkset.AlgHS256) requires.NoError(t, inner.KeyWrite(ctx, allowed)) requires.NoError(t, inner.KeyWrite(ctx, disallowed)) - store := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}) + store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, false) ma, err := store.MarshalWithOptions(ctx, jwkset.JWKMarshalOptions{}, jwkset.JWKValidateOptions{}) requires.NoError(t, err) From c7ca05cbff4e4c60ed5e60d6efb787358bd0a620 Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Thu, 18 Sep 2025 16:20:30 +0530 Subject: [PATCH 24/45] fix: update validation store unit tests --- .../authentication/validation_store_test.go | 165 +++++++++++++----- 1 file changed, 126 insertions(+), 39 deletions(-) diff --git a/router/pkg/authentication/validation_store_test.go b/router/pkg/authentication/validation_store_test.go index 7e13942c1c..120f93b97c 100644 --- a/router/pkg/authentication/validation_store_test.go +++ b/router/pkg/authentication/validation_store_test.go @@ -15,10 +15,7 @@ import ( ) func TestValidationStore(t *testing.T) { - // verify KeyWrite t.Run("verify KeyWrite", func(t *testing.T) { - ctx := context.Background() - t.Run("accepts supported algorithms without filter", func(t *testing.T) { inner := jwkset.NewMemoryStorage() store, _ := NewValidationStore(nil, inner, nil, false) @@ -27,6 +24,7 @@ func TestValidationStore(t *testing.T) { genHMACJWK(t, "hmac1", jwkset.AlgHS256), genEd25519JWK(t, "eddsa1"), } + ctx := context.Background() for _, k := range keys { requires.NoError(t, store.KeyWrite(ctx, k)) } @@ -40,6 +38,7 @@ func TestValidationStore(t *testing.T) { store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, false) allowed := genRSAJWK(t, "rsa-allowed", jwkset.AlgRS256) disallowed := genHMACJWK(t, "hmac-blocked", jwkset.AlgHS256) + ctx := context.Background() requires.NoError(t, store.KeyWrite(ctx, allowed)) requires.NoError(t, store.KeyWrite(ctx, disallowed)) // skipped, not error all, err := inner.KeyReadAll(ctx) @@ -47,53 +46,126 @@ func TestValidationStore(t *testing.T) { requires.Len(t, all, 1) requires.Equal(t, allowed.Marshal().KID, all[0].Marshal().KID) }) + + t.Run("accepts empty algorithm when allowed", func(t *testing.T) { + inner := jwkset.NewMemoryStorage() + store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, true) + ctx := context.Background() + noAlg := genRSAJWK(t, "noalg-write-allowed", jwkset.ALG("")) + requires.NoError(t, store.KeyWrite(ctx, noAlg)) + all, err := inner.KeyReadAll(ctx) + requires.NoError(t, err) + requires.Len(t, all, 1) + requires.Equal(t, "noalg-write-allowed", all[0].Marshal().KID) + }) + + t.Run("skips empty algorithm when not allowed", func(t *testing.T) { + inner := jwkset.NewMemoryStorage() + store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, false) + ctx := context.Background() + noAlg := genRSAJWK(t, "noalg-write-deny", jwkset.ALG("")) + requires.NoError(t, store.KeyWrite(ctx, noAlg)) + all, err := inner.KeyReadAll(ctx) + requires.NoError(t, err) + requires.Len(t, all, 0) + }) }) t.Run("verify KeyRead", func(t *testing.T) { - ctx := context.Background() - inner := jwkset.NewMemoryStorage() - allowed := genRSAJWK(t, "rsa-1", jwkset.AlgRS256) - disallowed := genHMACJWK(t, "hmac-1", jwkset.AlgHS256) - requires.NoError(t, inner.KeyWrite(ctx, allowed)) - requires.NoError(t, inner.KeyWrite(ctx, disallowed)) - store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, false) + t.Run("allowed and disallowed algorithms", func(t *testing.T) { + ctx := context.Background() + inner := jwkset.NewMemoryStorage() + allowed := genRSAJWK(t, "rsa-1", jwkset.AlgRS256) + disallowed := genHMACJWK(t, "hmac-1", jwkset.AlgHS256) + requires.NoError(t, inner.KeyWrite(ctx, allowed)) + requires.NoError(t, inner.KeyWrite(ctx, disallowed)) + store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, false) - _, err := store.KeyRead(ctx, allowed.Marshal().KID) - requires.NoError(t, err) - _, err = store.KeyRead(ctx, disallowed.Marshal().KID) - requires.ErrorContains(t, err, "unsupported algorithm") + _, err := store.KeyRead(ctx, allowed.Marshal().KID) + requires.NoError(t, err) + _, err = store.KeyRead(ctx, disallowed.Marshal().KID) + requires.ErrorContains(t, err, "unsupported algorithm") + }) + + t.Run("empty algorithm allowed returns key", func(t *testing.T) { + ctx := context.Background() + inner := jwkset.NewMemoryStorage() + noAlg := genRSAJWK(t, "noalg-read-allowed", jwkset.ALG("")) + requires.NoError(t, inner.KeyWrite(ctx, noAlg)) + store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, true) + + _, err := store.KeyRead(ctx, "noalg-read-allowed") + requires.NoError(t, err) + }) + + t.Run("empty algorithm not allowed returns error", func(t *testing.T) { + ctx := context.Background() + inner := jwkset.NewMemoryStorage() + noAlg := genRSAJWK(t, "noalg-read-deny", jwkset.ALG("")) + requires.NoError(t, inner.KeyWrite(ctx, noAlg)) + store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, false) + + _, err := store.KeyRead(ctx, "noalg-read-deny") + requires.ErrorContains(t, err, "unsupported algorithm") + }) }) t.Run("verify KeyReadAll", func(t *testing.T) { - ctx := context.Background() - inner := jwkset.NewMemoryStorage() - allowed := genRSAJWK(t, "rsa-1", jwkset.AlgRS256) - disallowed := genHMACJWK(t, "hmac-1", jwkset.AlgHS256) - requires.NoError(t, inner.KeyWrite(ctx, allowed)) - requires.NoError(t, inner.KeyWrite(ctx, disallowed)) - store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, false) + t.Run("filters to allowed algorithms", func(t *testing.T) { + ctx := context.Background() + inner := jwkset.NewMemoryStorage() + allowed := genRSAJWK(t, "rsa-1", jwkset.AlgRS256) + disallowed := genHMACJWK(t, "hmac-1", jwkset.AlgHS256) + requires.NoError(t, inner.KeyWrite(ctx, allowed)) + requires.NoError(t, inner.KeyWrite(ctx, disallowed)) + store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, false) - keys, err := store.KeyReadAll(ctx) - requires.NoError(t, err) - requires.Len(t, keys, 1) - m := keys[0].Marshal() - requires.Equal(t, allowed.Marshal().KID, m.KID) + keys, err := store.KeyReadAll(ctx) + requires.NoError(t, err) + requires.Len(t, keys, 1) + m := keys[0].Marshal() + requires.Equal(t, allowed.Marshal().KID, m.KID) + }) + + t.Run("includes empty algorithm when allowed", func(t *testing.T) { + ctx := context.Background() + inner2 := jwkset.NewMemoryStorage() + allowed2 := genRSAJWK(t, "rsa-2", jwkset.AlgRS256) + noAlg := genRSAJWK(t, "noalg-readall-allowed", jwkset.ALG("")) + requires.NoError(t, inner2.KeyWrite(ctx, allowed2)) + requires.NoError(t, inner2.KeyWrite(ctx, noAlg)) + store2, _ := NewValidationStore(zap.NewNop(), inner2, []string{"RS256"}, true) + keys2, err := store2.KeyReadAll(ctx) + requires.NoError(t, err) + requires.Len(t, keys2, 2) + }) }) t.Run("verify KeyReplaceAll", func(t *testing.T) { - ctx := context.Background() - inner := jwkset.NewMemoryStorage() - store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, false) - - allowed := genRSAJWK(t, "rsa-1", jwkset.AlgRS256) - disallowed := genHMACJWK(t, "hmac-1", jwkset.AlgHS256) - - requires.NoError(t, store.KeyReplaceAll(ctx, []jwkset.JWK{allowed, disallowed})) + t.Run("replaces with only allowed algorithms", func(t *testing.T) { + ctx := context.Background() + inner := jwkset.NewMemoryStorage() + store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, false) + allowed := genRSAJWK(t, "rsa-1", jwkset.AlgRS256) + disallowed := genHMACJWK(t, "hmac-1", jwkset.AlgHS256) + requires.NoError(t, store.KeyReplaceAll(ctx, []jwkset.JWK{allowed, disallowed})) + all, err := inner.KeyReadAll(ctx) + requires.NoError(t, err) + requires.Len(t, all, 1) + requires.Equal(t, allowed.Marshal().KID, all[0].Marshal().KID) + }) - all, err := inner.KeyReadAll(ctx) - requires.NoError(t, err) - requires.Len(t, all, 1) - requires.Equal(t, allowed.Marshal().KID, all[0].Marshal().KID) + t.Run("includes empty algorithm on replace when allowed", func(t *testing.T) { + ctx := context.Background() + inner := jwkset.NewMemoryStorage() + store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, true) + allowed := genRSAJWK(t, "rsa-3", jwkset.AlgRS256) + noAlg := genRSAJWK(t, "noalg-replace-allowed", jwkset.ALG("")) + requires.NoError(t, store.KeyReplaceAll(ctx, []jwkset.JWK{allowed, noAlg})) + all, err := inner.KeyReadAll(ctx) + requires.NoError(t, err) + requires.Len(t, all, 2) + }) }) t.Run("verify KeyDelete", func(t *testing.T) { @@ -205,15 +277,28 @@ func TestValidationStore(t *testing.T) { requires.NoError(t, err) requires.Equal(t, len(mb.Keys), len(ma.Keys)) }) + + // verify NewValidationStore supported algorithms return value + t.Run("verify ConstructorSupportedAlgorithms", func(t *testing.T) { + inner := jwkset.NewMemoryStorage() + _, algs := NewValidationStore(zap.NewNop(), inner, nil, false) + requires.Len(t, algs, len(supportedAlgorithms)) + _, algs2 := NewValidationStore(zap.NewNop(), inner, []string{"RS256", "INVALID"}, false) + requires.ElementsMatch(t, []string{"RS256"}, algs2) + }) } func genRSAJWK(t *testing.T, kid string, alg jwkset.ALG) jwkset.JWK { t.Helper() pk, err := rsa.GenerateKey(rand.Reader, 2048) requires.NoError(t, err) + meta := jwkset.JWKMetadataOptions{KID: kid, USE: jwkset.UseSig} + if alg.String() != "" { + meta.ALG = alg + } opts := jwkset.JWKOptions{ Marshal: jwkset.JWKMarshalOptions{Private: false}, - Metadata: jwkset.JWKMetadataOptions{ALG: alg, KID: kid, USE: jwkset.UseSig}, + Metadata: meta, } j, err := jwkset.NewJWKFromKey(pk, opts) requires.NoError(t, err) @@ -251,3 +336,5 @@ func genEd25519JWK(t *testing.T, kid string) jwkset.JWK { requires.NoError(t, err) return j } + +// genNoAlgRSAJWK creates a JWK with no ALG set to exercise allowEmptyAlgorithm behavior. From b81cd3c8cb4473886799d970a8bb1d06e8c394f1 Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Thu, 18 Sep 2025 16:28:02 +0530 Subject: [PATCH 25/45] fix: test cleanup --- .../authentication/validation_store_test.go | 30 +++++++++++++++++-- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/router/pkg/authentication/validation_store_test.go b/router/pkg/authentication/validation_store_test.go index 120f93b97c..c2a588f912 100644 --- a/router/pkg/authentication/validation_store_test.go +++ b/router/pkg/authentication/validation_store_test.go @@ -15,8 +15,11 @@ import ( ) func TestValidationStore(t *testing.T) { + t.Parallel() t.Run("verify KeyWrite", func(t *testing.T) { + t.Parallel() t.Run("accepts supported algorithms without filter", func(t *testing.T) { + t.Parallel() inner := jwkset.NewMemoryStorage() store, _ := NewValidationStore(nil, inner, nil, false) keys := []jwkset.JWK{ @@ -34,6 +37,7 @@ func TestValidationStore(t *testing.T) { }) t.Run("skips disallowed algorithms when filtered", func(t *testing.T) { + t.Parallel() inner := jwkset.NewMemoryStorage() store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, false) allowed := genRSAJWK(t, "rsa-allowed", jwkset.AlgRS256) @@ -48,6 +52,7 @@ func TestValidationStore(t *testing.T) { }) t.Run("accepts empty algorithm when allowed", func(t *testing.T) { + t.Parallel() inner := jwkset.NewMemoryStorage() store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, true) ctx := context.Background() @@ -60,6 +65,7 @@ func TestValidationStore(t *testing.T) { }) t.Run("skips empty algorithm when not allowed", func(t *testing.T) { + t.Parallel() inner := jwkset.NewMemoryStorage() store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, false) ctx := context.Background() @@ -67,12 +73,14 @@ func TestValidationStore(t *testing.T) { requires.NoError(t, store.KeyWrite(ctx, noAlg)) all, err := inner.KeyReadAll(ctx) requires.NoError(t, err) - requires.Len(t, all, 0) + requires.Empty(t, all) }) }) t.Run("verify KeyRead", func(t *testing.T) { + t.Parallel() t.Run("allowed and disallowed algorithms", func(t *testing.T) { + t.Parallel() ctx := context.Background() inner := jwkset.NewMemoryStorage() allowed := genRSAJWK(t, "rsa-1", jwkset.AlgRS256) @@ -88,6 +96,7 @@ func TestValidationStore(t *testing.T) { }) t.Run("empty algorithm allowed returns key", func(t *testing.T) { + t.Parallel() ctx := context.Background() inner := jwkset.NewMemoryStorage() noAlg := genRSAJWK(t, "noalg-read-allowed", jwkset.ALG("")) @@ -99,6 +108,7 @@ func TestValidationStore(t *testing.T) { }) t.Run("empty algorithm not allowed returns error", func(t *testing.T) { + t.Parallel() ctx := context.Background() inner := jwkset.NewMemoryStorage() noAlg := genRSAJWK(t, "noalg-read-deny", jwkset.ALG("")) @@ -111,7 +121,9 @@ func TestValidationStore(t *testing.T) { }) t.Run("verify KeyReadAll", func(t *testing.T) { + t.Parallel() t.Run("filters to allowed algorithms", func(t *testing.T) { + t.Parallel() ctx := context.Background() inner := jwkset.NewMemoryStorage() allowed := genRSAJWK(t, "rsa-1", jwkset.AlgRS256) @@ -128,6 +140,7 @@ func TestValidationStore(t *testing.T) { }) t.Run("includes empty algorithm when allowed", func(t *testing.T) { + t.Parallel() ctx := context.Background() inner2 := jwkset.NewMemoryStorage() allowed2 := genRSAJWK(t, "rsa-2", jwkset.AlgRS256) @@ -142,7 +155,9 @@ func TestValidationStore(t *testing.T) { }) t.Run("verify KeyReplaceAll", func(t *testing.T) { + t.Parallel() t.Run("replaces with only allowed algorithms", func(t *testing.T) { + t.Parallel() ctx := context.Background() inner := jwkset.NewMemoryStorage() store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, false) @@ -156,6 +171,7 @@ func TestValidationStore(t *testing.T) { }) t.Run("includes empty algorithm on replace when allowed", func(t *testing.T) { + t.Parallel() ctx := context.Background() inner := jwkset.NewMemoryStorage() store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, true) @@ -169,6 +185,7 @@ func TestValidationStore(t *testing.T) { }) t.Run("verify KeyDelete", func(t *testing.T) { + t.Parallel() ctx := context.Background() inner := jwkset.NewMemoryStorage() store, _ := NewValidationStore(zap.NewNop(), inner, nil, false) @@ -183,6 +200,7 @@ func TestValidationStore(t *testing.T) { }) t.Run("verify JSON", func(t *testing.T) { + t.Parallel() ctx := context.Background() inner := jwkset.NewMemoryStorage() allowed := genRSAJWK(t, "rsa-json", jwkset.AlgRS256) @@ -199,6 +217,7 @@ func TestValidationStore(t *testing.T) { }) t.Run("verify JSONPublic", func(t *testing.T) { + t.Parallel() ctx := context.Background() inner := jwkset.NewMemoryStorage() allowed := genRSAJWK(t, "rsa-json", jwkset.AlgRS256) @@ -215,6 +234,7 @@ func TestValidationStore(t *testing.T) { }) t.Run("verify JSONPrivate", func(t *testing.T) { + t.Parallel() ctx := context.Background() inner := jwkset.NewMemoryStorage() allowed := genRSAJWK(t, "rsa-json", jwkset.AlgRS256) @@ -231,6 +251,7 @@ func TestValidationStore(t *testing.T) { }) t.Run("verify JSONWithOptions", func(t *testing.T) { + t.Parallel() ctx := context.Background() inner := jwkset.NewMemoryStorage() allowed := genRSAJWK(t, "rsa-json", jwkset.AlgRS256) @@ -247,6 +268,7 @@ func TestValidationStore(t *testing.T) { }) t.Run("verify Marshal", func(t *testing.T) { + t.Parallel() ctx := context.Background() inner := jwkset.NewMemoryStorage() allowed := genRSAJWK(t, "rsa-json", jwkset.AlgRS256) @@ -259,10 +281,11 @@ func TestValidationStore(t *testing.T) { requires.NoError(t, err) mb, err := inner.Marshal(ctx) requires.NoError(t, err) - requires.Equal(t, len(mb.Keys), len(ma.Keys)) + requires.Len(t, ma.Keys, len(mb.Keys)) }) t.Run("verify MarshalWithOptions", func(t *testing.T) { + t.Parallel() ctx := context.Background() inner := jwkset.NewMemoryStorage() allowed := genRSAJWK(t, "rsa-json", jwkset.AlgRS256) @@ -275,11 +298,12 @@ func TestValidationStore(t *testing.T) { requires.NoError(t, err) mb, err := inner.MarshalWithOptions(ctx, jwkset.JWKMarshalOptions{}, jwkset.JWKValidateOptions{}) requires.NoError(t, err) - requires.Equal(t, len(mb.Keys), len(ma.Keys)) + requires.Len(t, ma.Keys, len(mb.Keys)) }) // verify NewValidationStore supported algorithms return value t.Run("verify ConstructorSupportedAlgorithms", func(t *testing.T) { + t.Parallel() inner := jwkset.NewMemoryStorage() _, algs := NewValidationStore(zap.NewNop(), inner, nil, false) requires.Len(t, algs, len(supportedAlgorithms)) From 747e23a5dd1ae0085c925ca8ba071ff8d2b0341a Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Thu, 18 Sep 2025 16:34:28 +0530 Subject: [PATCH 26/45] fix: cleanup --- .../authentication/validation_store_test.go | 27 ++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/router/pkg/authentication/validation_store_test.go b/router/pkg/authentication/validation_store_test.go index c2a588f912..3c4825e079 100644 --- a/router/pkg/authentication/validation_store_test.go +++ b/router/pkg/authentication/validation_store_test.go @@ -16,10 +16,13 @@ import ( func TestValidationStore(t *testing.T) { t.Parallel() + t.Run("verify KeyWrite", func(t *testing.T) { t.Parallel() + t.Run("accepts supported algorithms without filter", func(t *testing.T) { t.Parallel() + inner := jwkset.NewMemoryStorage() store, _ := NewValidationStore(nil, inner, nil, false) keys := []jwkset.JWK{ @@ -38,6 +41,7 @@ func TestValidationStore(t *testing.T) { t.Run("skips disallowed algorithms when filtered", func(t *testing.T) { t.Parallel() + inner := jwkset.NewMemoryStorage() store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, false) allowed := genRSAJWK(t, "rsa-allowed", jwkset.AlgRS256) @@ -53,6 +57,7 @@ func TestValidationStore(t *testing.T) { t.Run("accepts empty algorithm when allowed", func(t *testing.T) { t.Parallel() + inner := jwkset.NewMemoryStorage() store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, true) ctx := context.Background() @@ -66,6 +71,7 @@ func TestValidationStore(t *testing.T) { t.Run("skips empty algorithm when not allowed", func(t *testing.T) { t.Parallel() + inner := jwkset.NewMemoryStorage() store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, false) ctx := context.Background() @@ -79,8 +85,10 @@ func TestValidationStore(t *testing.T) { t.Run("verify KeyRead", func(t *testing.T) { t.Parallel() + t.Run("allowed and disallowed algorithms", func(t *testing.T) { t.Parallel() + ctx := context.Background() inner := jwkset.NewMemoryStorage() allowed := genRSAJWK(t, "rsa-1", jwkset.AlgRS256) @@ -97,6 +105,7 @@ func TestValidationStore(t *testing.T) { t.Run("empty algorithm allowed returns key", func(t *testing.T) { t.Parallel() + ctx := context.Background() inner := jwkset.NewMemoryStorage() noAlg := genRSAJWK(t, "noalg-read-allowed", jwkset.ALG("")) @@ -109,6 +118,7 @@ func TestValidationStore(t *testing.T) { t.Run("empty algorithm not allowed returns error", func(t *testing.T) { t.Parallel() + ctx := context.Background() inner := jwkset.NewMemoryStorage() noAlg := genRSAJWK(t, "noalg-read-deny", jwkset.ALG("")) @@ -122,8 +132,10 @@ func TestValidationStore(t *testing.T) { t.Run("verify KeyReadAll", func(t *testing.T) { t.Parallel() + t.Run("filters to allowed algorithms", func(t *testing.T) { t.Parallel() + ctx := context.Background() inner := jwkset.NewMemoryStorage() allowed := genRSAJWK(t, "rsa-1", jwkset.AlgRS256) @@ -141,6 +153,7 @@ func TestValidationStore(t *testing.T) { t.Run("includes empty algorithm when allowed", func(t *testing.T) { t.Parallel() + ctx := context.Background() inner2 := jwkset.NewMemoryStorage() allowed2 := genRSAJWK(t, "rsa-2", jwkset.AlgRS256) @@ -156,8 +169,10 @@ func TestValidationStore(t *testing.T) { t.Run("verify KeyReplaceAll", func(t *testing.T) { t.Parallel() + t.Run("replaces with only allowed algorithms", func(t *testing.T) { t.Parallel() + ctx := context.Background() inner := jwkset.NewMemoryStorage() store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, false) @@ -172,6 +187,7 @@ func TestValidationStore(t *testing.T) { t.Run("includes empty algorithm on replace when allowed", func(t *testing.T) { t.Parallel() + ctx := context.Background() inner := jwkset.NewMemoryStorage() store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, true) @@ -186,6 +202,7 @@ func TestValidationStore(t *testing.T) { t.Run("verify KeyDelete", func(t *testing.T) { t.Parallel() + ctx := context.Background() inner := jwkset.NewMemoryStorage() store, _ := NewValidationStore(zap.NewNop(), inner, nil, false) @@ -201,6 +218,7 @@ func TestValidationStore(t *testing.T) { t.Run("verify JSON", func(t *testing.T) { t.Parallel() + ctx := context.Background() inner := jwkset.NewMemoryStorage() allowed := genRSAJWK(t, "rsa-json", jwkset.AlgRS256) @@ -218,6 +236,7 @@ func TestValidationStore(t *testing.T) { t.Run("verify JSONPublic", func(t *testing.T) { t.Parallel() + ctx := context.Background() inner := jwkset.NewMemoryStorage() allowed := genRSAJWK(t, "rsa-json", jwkset.AlgRS256) @@ -235,6 +254,7 @@ func TestValidationStore(t *testing.T) { t.Run("verify JSONPrivate", func(t *testing.T) { t.Parallel() + ctx := context.Background() inner := jwkset.NewMemoryStorage() allowed := genRSAJWK(t, "rsa-json", jwkset.AlgRS256) @@ -252,6 +272,7 @@ func TestValidationStore(t *testing.T) { t.Run("verify JSONWithOptions", func(t *testing.T) { t.Parallel() + ctx := context.Background() inner := jwkset.NewMemoryStorage() allowed := genRSAJWK(t, "rsa-json", jwkset.AlgRS256) @@ -269,6 +290,7 @@ func TestValidationStore(t *testing.T) { t.Run("verify Marshal", func(t *testing.T) { t.Parallel() + ctx := context.Background() inner := jwkset.NewMemoryStorage() allowed := genRSAJWK(t, "rsa-json", jwkset.AlgRS256) @@ -286,6 +308,7 @@ func TestValidationStore(t *testing.T) { t.Run("verify MarshalWithOptions", func(t *testing.T) { t.Parallel() + ctx := context.Background() inner := jwkset.NewMemoryStorage() allowed := genRSAJWK(t, "rsa-json", jwkset.AlgRS256) @@ -301,9 +324,9 @@ func TestValidationStore(t *testing.T) { requires.Len(t, ma.Keys, len(mb.Keys)) }) - // verify NewValidationStore supported algorithms return value t.Run("verify ConstructorSupportedAlgorithms", func(t *testing.T) { t.Parallel() + inner := jwkset.NewMemoryStorage() _, algs := NewValidationStore(zap.NewNop(), inner, nil, false) requires.Len(t, algs, len(supportedAlgorithms)) @@ -360,5 +383,3 @@ func genEd25519JWK(t *testing.T, kid string) jwkset.JWK { requires.NoError(t, err) return j } - -// genNoAlgRSAJWK creates a JWK with no ALG set to exercise allowEmptyAlgorithm behavior. From 482d572352b2d3a6776c44e8fda73a6d2a651756 Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Mon, 22 Sep 2025 15:55:10 +0530 Subject: [PATCH 27/45] fix: cleanup validation store --- router-tests/authentication_test.go | 15 +- router-tests/go.mod | 3 +- router-tests/go.sum | 6 +- router/go.mod | 3 +- router/go.sum | 6 +- .../pkg/authentication/jwks_token_decoder.go | 53 +-- router/pkg/authentication/keyfunc/keyfunc.go | 300 ++++++++++++++ .../authentication/keyfunc/keyfunc_test.go | 302 ++++++++++++++ router/pkg/authentication/validation_store.go | 173 -------- .../authentication/validation_store_test.go | 385 ------------------ router/pkg/config/config.go | 3 +- router/pkg/config/testdata/config_full.json | 9 +- 12 files changed, 632 insertions(+), 626 deletions(-) create mode 100644 router/pkg/authentication/keyfunc/keyfunc.go create mode 100644 router/pkg/authentication/keyfunc/keyfunc_test.go delete mode 100644 router/pkg/authentication/validation_store.go delete mode 100644 router/pkg/authentication/validation_store_test.go diff --git a/router-tests/authentication_test.go b/router-tests/authentication_test.go index f38bf054be..280130c87b 100644 --- a/router-tests/authentication_test.go +++ b/router-tests/authentication_test.go @@ -2228,7 +2228,6 @@ func TestSupportedAlgorithms(t *testing.T) { t.Parallel() body := testRequest(t, xEnv, authHeader(token), true) require.Equal(t, employeesExpectedData, string(body)) - }) t.Run("Should fail when providing no Token", func(t *testing.T) { @@ -2844,9 +2843,8 @@ func TestAudienceValidation(t *testing.T) { authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ { - URL: authServer.JWKSURL(), - RefreshInterval: time.Second * 5, - AllowEmptyAlgorithm: true, + URL: authServer.JWKSURL(), + RefreshInterval: time.Second * 5, }, }) @@ -2879,7 +2877,7 @@ func TestAudienceValidation(t *testing.T) { rsaCrypto, err := jwks.NewRSACrypto("", "", 2048) require.NoError(t, err) - + authServer, err := jwks.NewServerWithCrypto(t, rsaCrypto) require.NoError(t, err) t.Cleanup(authServer.Close) @@ -2888,10 +2886,9 @@ func TestAudienceValidation(t *testing.T) { authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ { - URL: authServer.JWKSURL(), - RefreshInterval: time.Second * 5, - AllowedAlgorithms: []string{string(allowedAlgorithm)}, - AllowEmptyAlgorithm: true, + URL: authServer.JWKSURL(), + RefreshInterval: time.Second * 5, + AllowedAlgorithms: []string{string(allowedAlgorithm)}, }, }) diff --git a/router-tests/go.mod b/router-tests/go.mod index 07a86f3f43..9d2d9bb3b9 100644 --- a/router-tests/go.mod +++ b/router-tests/go.mod @@ -3,7 +3,7 @@ module github.com/wundergraph/cosmo/router-tests go 1.25 require ( - github.com/MicahParks/jwkset v0.9.0 + github.com/MicahParks/jwkset v0.11.0 github.com/buger/jsonparser v1.1.1 github.com/cloudflare/backoff v0.0.0-20240920015135-e46b80a3a7d0 github.com/golang-jwt/jwt/v5 v5.2.2 @@ -45,7 +45,6 @@ require ( connectrpc.com/connect v1.16.2 // indirect github.com/99designs/gqlgen v0.17.76 // indirect github.com/KimMachineGun/automemlimit v0.6.1 // indirect - github.com/MicahParks/keyfunc/v3 v3.3.5 // indirect github.com/agnivade/levenshtein v1.2.1 // indirect github.com/andybalholm/brotli v1.1.0 // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect diff --git a/router-tests/go.sum b/router-tests/go.sum index 390d7e1971..39cb22207b 100644 --- a/router-tests/go.sum +++ b/router-tests/go.sum @@ -5,10 +5,8 @@ github.com/99designs/gqlgen v0.17.76/go.mod h1:miiU+PkAnTIDKMQ1BseUOIVeQHoiwYDZG github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/KimMachineGun/automemlimit v0.6.1 h1:ILa9j1onAAMadBsyyUJv5cack8Y1WT26yLj/V+ulKp8= github.com/KimMachineGun/automemlimit v0.6.1/go.mod h1:T7xYht7B8r6AG/AqFcUdc7fzd2bIdBKmepfP2S1svPY= -github.com/MicahParks/jwkset v0.9.0 h1:xDlGu6mZJdJ+mgAI4mIRqWm2p8Vrx0U98LMgRObw46M= -github.com/MicahParks/jwkset v0.9.0/go.mod h1:fVrj6TmG1aKlJEeceAz7JsXGTXEn72zP1px3us53JrA= -github.com/MicahParks/keyfunc/v3 v3.3.5 h1:7ceAJLUAldnoueHDNzF8Bx06oVcQ5CfJnYwNt1U3YYo= -github.com/MicahParks/keyfunc/v3 v3.3.5/go.mod h1:SdCCyMJn/bYqWDvARspC6nCT8Sk74MjuAY22C7dCST8= +github.com/MicahParks/jwkset v0.11.0 h1:yc0zG+jCvZpWgFDFmvs8/8jqqVBG9oyIbmBtmjOhoyQ= +github.com/MicahParks/jwkset v0.11.0/go.mod h1:U2oRhRaLgDCLjtpGL2GseNKGmZtLs/3O7p+OZaL5vo0= github.com/PuerkitoBio/goquery v1.10.3 h1:pFYcNSqHxBD06Fpj/KsbStFRsgRATgnf3LeXiUkhzPo= github.com/PuerkitoBio/goquery v1.10.3/go.mod h1:tMUX0zDMHXYlAQk6p35XxQMqMweEKB7iK7iLNd4RH4Y= github.com/agnivade/levenshtein v1.2.1 h1:EHBY3UOn1gwdy/VbFwgo4cxecRznFk7fKWN1KOX7eoM= diff --git a/router/go.mod b/router/go.mod index ae51779a4f..00e4d36b3f 100644 --- a/router/go.mod +++ b/router/go.mod @@ -58,8 +58,7 @@ require ( require ( github.com/KimMachineGun/automemlimit v0.6.1 - github.com/MicahParks/jwkset v0.9.0 - github.com/MicahParks/keyfunc/v3 v3.3.5 + github.com/MicahParks/jwkset v0.11.0 github.com/alicebob/miniredis/v2 v2.34.0 github.com/caarlos0/env/v11 v11.3.1 github.com/cep21/circuit/v4 v4.0.0 diff --git a/router/go.sum b/router/go.sum index cdc81151aa..cbf8c86cc3 100644 --- a/router/go.sum +++ b/router/go.sum @@ -5,10 +5,8 @@ github.com/99designs/gqlgen v0.17.49/go.mod h1:tC8YFVZMed81x7UJ7ORUwXF4Kn6SXuucF github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/KimMachineGun/automemlimit v0.6.1 h1:ILa9j1onAAMadBsyyUJv5cack8Y1WT26yLj/V+ulKp8= github.com/KimMachineGun/automemlimit v0.6.1/go.mod h1:T7xYht7B8r6AG/AqFcUdc7fzd2bIdBKmepfP2S1svPY= -github.com/MicahParks/jwkset v0.9.0 h1:xDlGu6mZJdJ+mgAI4mIRqWm2p8Vrx0U98LMgRObw46M= -github.com/MicahParks/jwkset v0.9.0/go.mod h1:fVrj6TmG1aKlJEeceAz7JsXGTXEn72zP1px3us53JrA= -github.com/MicahParks/keyfunc/v3 v3.3.5 h1:7ceAJLUAldnoueHDNzF8Bx06oVcQ5CfJnYwNt1U3YYo= -github.com/MicahParks/keyfunc/v3 v3.3.5/go.mod h1:SdCCyMJn/bYqWDvARspC6nCT8Sk74MjuAY22C7dCST8= +github.com/MicahParks/jwkset v0.11.0 h1:yc0zG+jCvZpWgFDFmvs8/8jqqVBG9oyIbmBtmjOhoyQ= +github.com/MicahParks/jwkset v0.11.0/go.mod h1:U2oRhRaLgDCLjtpGL2GseNKGmZtLs/3O7p+OZaL5vo0= github.com/agnivade/levenshtein v1.1.1 h1:QY8M92nrzkmr798gCo3kmMyqXFzdQVpxLlGPRBij0P8= github.com/agnivade/levenshtein v1.1.1/go.mod h1:veldBMzWxcCG2ZvUTKD2kJNRdCk5hVbJomOvKkmgYbo= github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 h1:uvdUDbHQHO85qeSydJtItA4T55Pw6BtAejd0APRJOCE= diff --git a/router/pkg/authentication/jwks_token_decoder.go b/router/pkg/authentication/jwks_token_decoder.go index b358978363..27f70cbc78 100644 --- a/router/pkg/authentication/jwks_token_decoder.go +++ b/router/pkg/authentication/jwks_token_decoder.go @@ -5,13 +5,13 @@ import ( "errors" "fmt" "net/http" - "slices" "time" + "github.com/wundergraph/cosmo/router/pkg/authentication/keyfunc" + "golang.org/x/time/rate" "github.com/MicahParks/jwkset" - "github.com/MicahParks/keyfunc/v3" "github.com/golang-jwt/jwt/v5" "github.com/wundergraph/cosmo/router/internal/httpclient" "go.uber.org/zap" @@ -69,9 +69,8 @@ type audKey struct { type audienceSet map[string]struct{} type keyFuncWithOpts struct { - keyFunc keyfunc.Keyfunc - allowEmptyAlgorithm bool - allowedAlgorithms []string + keyFunc keyfunc.Keyfunc + allowedAlgorithms []string } func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKSConfig) (TokenDecoder, error) { @@ -87,8 +86,6 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS l := logger.With(zap.String("url", c.URL)) - newValidationStore, processedAllowedAlgorithms := NewValidationStore(logger, nil, c.AllowedAlgorithms, c.AllowEmptyAlgorithm) - jwksetHTTPStorageOptions := jwkset.HTTPClientStorageOptions{ Client: newOIDCDiscoveryClient(httpclient.NewRetryableHTTPClient(l)), Ctx: ctx, // Used to end background refresh goroutine. @@ -99,7 +96,7 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS l.Error("Failed to refresh HTTP JWK Set from remote HTTP resource.", zap.Error(err)) }, RefreshInterval: c.RefreshInterval, - Storage: newValidationStore, + Storage: jwkset.NewMemoryStorage(), } store, err := jwkset.NewStorageFromHTTP(c.URL, jwksetHTTPStorageOptions) @@ -122,14 +119,13 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS jwksetHTTPClientOptions.RateLimitWaitMax = c.RefreshUnknownKID.MaxWait } - jwks, err := createKeyFunc(ctx, jwksetHTTPClientOptions) + jwks, err := createKeyFunc(ctx, jwksetHTTPClientOptions, c.AllowedAlgorithms) if err != nil { return nil, err } keyFuncMap[key] = keyFuncWithOpts{ - keyFunc: jwks, - allowEmptyAlgorithm: c.AllowEmptyAlgorithm, - allowedAlgorithms: processedAllowedAlgorithms, + keyFunc: jwks, + allowedAlgorithms: c.AllowedAlgorithms, } } else if c.Secret != "" { @@ -176,7 +172,7 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS PrioritizeHTTP: false, } - jwks, err := createKeyFunc(ctx, jwksetHTTPClientOptions) + jwks, err := createKeyFunc(ctx, jwksetHTTPClientOptions, make([]string, 0)) if err != nil { return nil, err } @@ -190,28 +186,6 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS keyFuncWrapper := jwt.Keyfunc(func(token *jwt.Token) (any, error) { var errJoin error for key, keyFuncAndOpts := range keyFuncMap { - // When an algorithm is actually provided in the jwks the current keyfunc will validate the - // jwts algorithm with it. But when no algorithm is provided (alg: none or missing alg) - // the default keyfunc will not validate the algorithm as it has nothing to cross check. - if keyFuncAndOpts.allowEmptyAlgorithm { - // We use the same error messages as keyfunc.Keyfunc for consistency - algInter, ok := token.Header["alg"] - if !ok { - errJoin = errors.Join(errJoin, fmt.Errorf("%w: could not find alg in JWT header", keyfunc.ErrKeyfunc)) - continue - } - alg, ok := algInter.(string) - if !ok { - errJoin = errors.Join(errJoin, fmt.Errorf(`%w: the JWT header did not contain the "alg" parameter, which is required by RFC 7515 section 4.1.1`, keyfunc.ErrKeyfunc)) - continue - } - - // This is a custom validation different from keyfunc.Keyfunc - if !slices.Contains(keyFuncAndOpts.allowedAlgorithms, alg) { - errJoin = errors.Join(errJoin, fmt.Errorf("%w: could not find alg %s in allow list", keyfunc.ErrKeyfunc, alg)) - continue - } - } pub, err := keyFuncAndOpts.keyFunc.Keyfunc(token) if err != nil { @@ -250,16 +224,17 @@ func getAudienceSet(audiences []string) audienceSet { return audSet } -func createKeyFunc(ctx context.Context, options jwkset.HTTPClientOptions) (keyfunc.Keyfunc, error) { +func createKeyFunc(ctx context.Context, options jwkset.HTTPClientOptions, algorithms []string) (keyfunc.Keyfunc, error) { combined, err := jwkset.NewHTTPClient(options) if err != nil { return nil, fmt.Errorf("failed to create HTTP client storage for JWK provider: %w", err) } keyfuncOptions := keyfunc.Options{ - Ctx: ctx, - Storage: combined, - UseWhitelist: []jwkset.USE{jwkset.UseSig}, + Ctx: ctx, + Storage: combined, + UseWhitelist: []jwkset.USE{jwkset.UseSig}, + AllowedAlgorithms: algorithms, } jwks, err := keyfunc.New(keyfuncOptions) diff --git a/router/pkg/authentication/keyfunc/keyfunc.go b/router/pkg/authentication/keyfunc/keyfunc.go new file mode 100644 index 0000000000..1c37d34b53 --- /dev/null +++ b/router/pkg/authentication/keyfunc/keyfunc.go @@ -0,0 +1,300 @@ +// This is forked from https://github.com/MicahParks/keyfunc/blob/main/keyfunc.go +// Copyrights go to the original author. +package keyfunc + +import ( + "context" + "crypto" + "encoding/json" + "errors" + "fmt" + "log/slog" + "slices" + "time" + + "github.com/MicahParks/jwkset" + "github.com/golang-jwt/jwt/v5" + "golang.org/x/time/rate" +) + +var ( + // ErrKeyfunc is returned when a keyfunc error occurs. + ErrKeyfunc = errors.New("failed keyfunc") +) + +// Keyfunc is meant to be used as the jwt.Keyfunc function for github.com/golang-jwt/jwt/v5. It uses +// github.com/MicahParks/jwkset as a JWK Set storage. +type Keyfunc interface { + Keyfunc(token *jwt.Token) (any, error) + KeyfuncCtx(ctx context.Context) jwt.Keyfunc + Storage() jwkset.Storage + VerificationKeySet(ctx context.Context) (jwt.VerificationKeySet, error) +} + +// Options are used to create a new Keyfunc. +type Options struct { + Ctx context.Context + Storage jwkset.Storage + UseWhitelist []jwkset.USE + + // Custom Non Base on original keyfunc + AllowedAlgorithms []string +} + +// Override is used to change specific default behaviors. +type Override struct { + // HTTPTimeout is from https://pkg.go.dev/github.com/MicahParks/jwkset#HTTPClientStorageOptions + HTTPTimeout time.Duration + // RateLimitWaitMax is from https://pkg.go.dev/github.com/MicahParks/jwkset#HTTPClientOptions + RateLimitWaitMax time.Duration + // RefreshErrorHandlerFunc is a function that accepts the URL of the remote JWK Set storage and returns the + // RefreshErrorHandler from https://pkg.go.dev/github.com/MicahParks/jwkset#HTTPClientStorageOptions + RefreshErrorHandlerFunc func(u string) func(ctx context.Context, err error) + // RefreshInterval is from https://pkg.go.dev/github.com/MicahParks/jwkset#HTTPClientStorageOptions + RefreshInterval time.Duration + // RefreshUnknownKID is from https://pkg.go.dev/github.com/MicahParks/jwkset#HTTPClientOptions + RefreshUnknownKID *rate.Limiter + // ValidationSkipAll is copied to SkipAll in https://pkg.go.dev/github.com/MicahParks/jwkset#JWKValidateOptions + ValidationSkipAll bool +} + +type keyfunc struct { + ctx context.Context + storage jwkset.Storage + useWhitelist []jwkset.USE + allowedAlgorithms []string +} + +// New creates a new Keyfunc. +func New(options Options) (Keyfunc, error) { + ctx := options.Ctx + if ctx == nil { + ctx = context.Background() + } + if options.Storage == nil { + return nil, fmt.Errorf("%w: no JWK Set storage given in options", ErrKeyfunc) + } + k := keyfunc{ + ctx: ctx, + storage: options.Storage, + useWhitelist: options.UseWhitelist, + allowedAlgorithms: options.AllowedAlgorithms, + } + return k, nil +} + +// NewDefault creates a new Keyfunc with a default JWK Set storage and options. +// +// This will launch "refresh goroutine" to automatically refresh the remote HTTP resources. +func NewDefault(urls []string) (Keyfunc, error) { + return NewDefaultCtx(context.Background(), urls) +} + +// NewDefaultCtx creates a new Keyfunc with a default JWK Set storage and options. The context is used to end the +// "refresh goroutine". +// +// This will launch "refresh goroutine" to automatically refresh the remote HTTP resources. +func NewDefaultCtx(ctx context.Context, urls []string) (Keyfunc, error) { + client, err := jwkset.NewDefaultHTTPClientCtx(ctx, urls) + if err != nil { + return nil, err + } + options := Options{ + Storage: client, + } + return New(options) +} + +// NewDefaultOverrideCtx creates a new Keyfunc with a default JWK Set storage and options. The context is used to end +// the "refresh goroutine". The override parameter is used to change specific default behaviors. +// +// This will launch "refresh goroutine" to automatically refresh remote HTTP resources. +func NewDefaultOverrideCtx(ctx context.Context, urls []string, override Override) (Keyfunc, error) { + rateLimitWaitMax := time.Minute + if override.RateLimitWaitMax != 0 { + rateLimitWaitMax = override.RateLimitWaitMax + } + refreshErrorHandler := func(u string) func(ctx context.Context, err error) { + return func(ctx context.Context, err error) { + slog.Default().ErrorContext(ctx, "Failed to refresh HTTP JWK Set from remote HTTP resource.", + "error", err, + "url", u, + ) + } + } + if override.RefreshErrorHandlerFunc != nil { + refreshErrorHandler = override.RefreshErrorHandlerFunc + } + refreshInterval := time.Hour + if override.RefreshInterval > 0 { + refreshInterval = override.RefreshInterval + } + refreshUnknownKID := rate.NewLimiter(rate.Every(5*time.Minute), 1) + if override.RefreshUnknownKID != nil { + refreshUnknownKID = override.RefreshUnknownKID + } + + clientOptions := jwkset.HTTPClientOptions{ + HTTPURLs: make(map[string]jwkset.Storage), + RateLimitWaitMax: rateLimitWaitMax, + RefreshUnknownKID: refreshUnknownKID, + } + for _, u := range urls { + errorHandler := refreshErrorHandler(u) + options := jwkset.HTTPClientStorageOptions{ + Ctx: ctx, + NoErrorReturnFirstHTTPReq: true, + RefreshErrorHandler: errorHandler, + RefreshInterval: refreshInterval, + ValidateOptions: jwkset.JWKValidateOptions{ + SkipAll: override.ValidationSkipAll, + }, + } + + if override.HTTPTimeout > 0 { + options.HTTPTimeout = override.HTTPTimeout + } + + c, err := jwkset.NewStorageFromHTTP(u, options) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP client storage for %q: %w", u, errors.Join(err, jwkset.ErrNewClient)) + } + clientOptions.HTTPURLs[u] = c + } + storage, err := jwkset.NewHTTPClient(clientOptions) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP client storage: %w", errors.Join(err, jwkset.ErrNewClient)) + } + options := Options{ + Ctx: ctx, + Storage: storage, + UseWhitelist: nil, + } + return New(options) +} + +// NewJWKJSON creates a new Keyfunc from raw JWK JSON. +func NewJWKJSON(raw json.RawMessage) (Keyfunc, error) { + marshalOptions := jwkset.JWKMarshalOptions{ + Private: true, + } + jwk, err := jwkset.NewJWKFromRawJSON(raw, marshalOptions, jwkset.JWKValidateOptions{}) + if err != nil { + return nil, fmt.Errorf("%w: could not create JWK from raw JSON", errors.Join(err, ErrKeyfunc)) + } + store := jwkset.NewMemoryStorage() + err = store.KeyWrite(context.Background(), jwk) + if err != nil { + return nil, fmt.Errorf("%w: could not write JWK to storage", errors.Join(err, ErrKeyfunc)) + } + options := Options{ + Storage: store, + } + return New(options) +} + +// NewJWKSetJSON creates a new Keyfunc from raw JWK Set JSON. +func NewJWKSetJSON(raw json.RawMessage) (Keyfunc, error) { + var jwks jwkset.JWKSMarshal + err := json.Unmarshal(raw, &jwks) + if err != nil { + return nil, fmt.Errorf("%w: could not unmarshal raw JWK Set JSON", errors.Join(err, ErrKeyfunc)) + } + store, err := jwks.ToStorage() + if err != nil { + return nil, fmt.Errorf("%w: could not create JWK Set storage", errors.Join(err, ErrKeyfunc)) + } + options := Options{ + Storage: store, + } + return New(options) +} + +func (k keyfunc) KeyfuncCtx(ctx context.Context) jwt.Keyfunc { + return func(token *jwt.Token) (any, error) { + kidInter, ok := token.Header[jwkset.HeaderKID] + if !ok { + return k.VerificationKeySet(ctx) + } + kid, ok := kidInter.(string) + if !ok { + return nil, fmt.Errorf("%w: could not convert kid in JWT header to string", ErrKeyfunc) + } + algInter, ok := token.Header["alg"] + if !ok { + return nil, fmt.Errorf("%w: could not find alg in JWT header", ErrKeyfunc) + } + alg, ok := algInter.(string) + if !ok { + // For test coverage purposes, this should be impossible to reach because the JWT package rejects a token + // without an alg parameter in the header before calling jwt.Keyfunc. + return nil, fmt.Errorf(`%w: the JWT header did not contain the "alg" parameter, which is required by RFC 7515 section 4.1.1`, ErrKeyfunc) + } + + // When an algorithm is actually provided in the jwks the current keyfunc will validate the + // jwks algorithm with it. But when no algorithm is provided (alg: none or missing alg) + // the default keyfunc will not validate the algorithm as it has nothing to cross check. + if len(k.allowedAlgorithms) > 0 { + // This is a custom validation different from the original keyfunc.Keyfunc + if !slices.Contains(k.allowedAlgorithms, alg) { + return nil, fmt.Errorf("%w: could not find alg %s in allow list", ErrKeyfunc, alg) + } + } + + jwk, err := k.storage.KeyRead(ctx, kid) + if err != nil { + return nil, fmt.Errorf("%w: could not read JWK from storage", errors.Join(err, ErrKeyfunc)) + } + + if a := jwk.Marshal().ALG.String(); a != "" && a != alg { + return nil, fmt.Errorf(`%w: JWK "alg" parameter value %q does not match token "alg" parameter value %q`, ErrKeyfunc, a, alg) + } + if len(k.useWhitelist) > 0 { + found := false + for _, u := range k.useWhitelist { + if jwk.Marshal().USE == u { + found = true + break + } + } + if !found { + return nil, fmt.Errorf(`%w: JWK "use" parameter value %q is not in whitelist`, ErrKeyfunc, jwk.Marshal().USE) + } + } + + key := jwk.Key() + pk, ok := key.(publicKeyer) + if ok { + key = pk.Public() + } + + return key, nil + } +} +func (k keyfunc) Keyfunc(token *jwt.Token) (any, error) { + keyF := k.KeyfuncCtx(k.ctx) + return keyF(token) +} +func (k keyfunc) Storage() jwkset.Storage { + return k.storage +} +func (k keyfunc) VerificationKeySet(ctx context.Context) (jwt.VerificationKeySet, error) { + jwk, err := k.storage.KeyReadAll(ctx) + if err != nil { + return jwt.VerificationKeySet{}, fmt.Errorf("failed to read all JWK from storage: %w", errors.Join(err, ErrKeyfunc)) + } + var allKeys jwt.VerificationKeySet + for _, j := range jwk { + key := j.Key() + pk, ok := key.(publicKeyer) + if ok { + key = pk.Public() + } + allKeys.Keys = append(allKeys.Keys, key) + } + return allKeys, nil +} + +type publicKeyer interface { + Public() crypto.PublicKey +} diff --git a/router/pkg/authentication/keyfunc/keyfunc_test.go b/router/pkg/authentication/keyfunc/keyfunc_test.go new file mode 100644 index 0000000000..1b2b28843f --- /dev/null +++ b/router/pkg/authentication/keyfunc/keyfunc_test.go @@ -0,0 +1,302 @@ +// This is forked from https://github.com/MicahParks/keyfunc/blob/main/keyfunc.go +// Copyrights go to the original author. +package keyfunc + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "reflect" + "testing" + + "github.com/MicahParks/jwkset" + "github.com/golang-jwt/jwt/v5" +) + +const ( + keyID = "my-key-id" +) + +func TestNew(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + _, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("Failed to generate ED25519 key pair. Error: %s", err) + } + jwk, err := jwkset.NewJWKFromKey(priv, jwkset.JWKOptions{}) + if err != nil { + t.Fatalf("Failed to create JWK from ED25519 private key. Error: %s", err) + } + + serverStore := jwkset.NewMemoryStorage() + err = serverStore.KeyWrite(ctx, jwk) + if err != nil { + t.Fatalf("Failed to write ED25519 public key to server store. Error: %s", err) + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rawJWKS, err := serverStore.JSONPrivate(ctx) + if err != nil { + t.Fatalf("Failed to get JWK Set JSON from server store. Error: %s", err) + } + _, _ = w.Write(rawJWKS) + })) + defer server.Close() + + token := jwt.New(jwt.SigningMethodEdDSA) + token.Header[jwkset.HeaderKID] = keyID + signed, err := token.SignedString(priv) + if err != nil { + t.Fatalf("Failed to sign JWT. Error: %s", err) + } + + clientStore, err := jwkset.NewDefaultHTTPClient([]string{server.URL}) + if err != nil { + t.Fatalf("Failed to create client store. Error: %s", err) + } + options := Options{ + Ctx: ctx, + Storage: clientStore, + UseWhitelist: []jwkset.USE{jwkset.UseSig}, + } + k, err := New(options) + if err != nil { + t.Fatalf("Failed to create keyfunc. Error: %s", err) + } + + _, err = jwt.Parse(signed, k.Keyfunc) + if !errors.Is(err, ErrKeyfunc) { + t.Fatalf("Expected ErrKeyfunc for missing Key ID in header, but got %s.", err) + } + + metadata := jwkset.JWKMetadataOptions{ + KID: keyID, + USE: jwkset.UseSig, + } + jwkOptions := jwkset.JWKOptions{ + Metadata: metadata, + } + jwk, err = jwkset.NewJWKFromKey(priv, jwkOptions) + if err != nil { + t.Fatalf("Failed to create JWK from ED25519 private key. Error: %s", err) + } + err = serverStore.KeyWrite(ctx, jwk) + if err != nil { + t.Fatalf("Failed to write ED25519 public key to server store. Error: %s", err) + } + + clientStore, err = jwkset.NewDefaultHTTPClient([]string{server.URL}) + if err != nil { + t.Fatalf("Failed to create client store. Error: %s", err) + } + options.Storage = clientStore + k, err = New(options) + if err != nil { + t.Fatalf("Failed to create keyfunc. Error: %s", err) + } + + _, err = jwt.Parse(signed, k.Keyfunc) + if err != nil { + t.Fatalf("Failed to parse JWT. Error: %s", err) + } + + if !reflect.DeepEqual(k.Storage(), clientStore) { + t.Fatalf("Expected client store, but got something else.") + } + + _, err = NewDefault([]string{server.URL}) + if err != nil { + t.Fatalf("Failed to create keyfunc. Error: %s", err) + } + + _, err = NewDefaultOverrideCtx(ctx, []string{server.URL}, Override{}) + if err != nil { + t.Fatalf("Failed to create keyfunc with overrides. Error: %s", err) + } +} + +func TestNewErr(t *testing.T) { + _, err := New(Options{}) + if !errors.Is(err, ErrKeyfunc) { + t.Error("Expected ErrKeyfunc, but got nil.") + } +} + +func TestNewJWKJSON(t *testing.T) { + // Get the JWK as JSON. + jwksJSON := json.RawMessage(`{"kty": "RSA","e": "AQAB","kid": "ee8d626d","n": "gRda5b0pkgTytDuLrRnNSYhvfMIyM0ASq2ZggY4dVe12JV8N7lyXilyqLKleD-2lziivvzE8O8CdIC2vUf0tBD7VuMyldnZruSEZWCuKJPdgKgy9yPpShmD2NyhbwQIAbievGMJIp_JMwz8MkdY5pzhPECGNgCEtUAmsrrctP5V8HuxaxGt9bb-DdPXkYWXW3MPMSlVpGZ5GiIeTABxqYNG2MSoYeQ9x8O3y488jbassTqxExI_4w9MBQBJR9HIXjWrrrenCcDlMY71rzkbdj3mmcn9xMq2vB5OhfHyHTihbUPLSm83aFWSuW9lE7ogMc93XnrB8evIAk6VfsYlS9Q"}`) + + // Create the keyfunc.Keyfunc. + k, err := NewJWKJSON(jwksJSON) + if err != nil { + t.Fatalf("Failed to create a keyfunc.Keyfunc.\nError: %s", err) + } + + // Get a JWT to parse. + jwtB64 := "eyJraWQiOiJlZThkNjI2ZCIsInR5cCI6IkpXVCIsImFsZyI6IlJTMjU2In0.eyJzdWIiOiJXZWlkb25nIiwiYXVkIjoiVGFzaHVhbiIsImlzcyI6Imp3a3Mtc2VydmljZS5hcHBzcG90LmNvbSIsImlhdCI6MTYzMTM2OTk1NSwianRpIjoiNDY2M2E5MTAtZWU2MC00NzcwLTgxNjktY2I3NDdiMDljZjU0In0.LwD65d5h6U_2Xco81EClMa_1WIW4xXZl8o4b7WzY_7OgPD2tNlByxvGDzP7bKYA9Gj--1mi4Q4li4CAnKJkaHRYB17baC0H5P9lKMPuA6AnChTzLafY6yf-YadA7DmakCtIl7FNcFQQL2DXmh6gS9J6TluFoCIXj83MqETbDWpL28o3XAD_05UP8VLQzH2XzyqWKi97mOuvz-GsDp9mhBYQUgN3csNXt2v2l-bUPWe19SftNej0cxddyGu06tXUtaS6K0oe0TTbaqc3hmfEiu5G0J8U6ztTUMwXkBvaknE640NPgMQJqBaey0E4u0txYgyvMvvxfwtcOrDRYqYPBnA" + + // Parse the JWT. + token, err := jwt.Parse(jwtB64, k.Keyfunc) + if err != nil { + t.Fatalf("Failed to parse the JWT.\nError: %s", err) + } + + // Check if the token is valid. + if !token.Valid { + t.Fatalf("The token is not valid.") + } +} + +func TestNewJWKSetJSON(t *testing.T) { + // Get the JWK Set as JSON. + jwksJSON := json.RawMessage(`{"keys":[{"kty":"RSA","e":"AQAB","kid":"ee8d626d","n":"gRda5b0pkgTytDuLrRnNSYhvfMIyM0ASq2ZggY4dVe12JV8N7lyXilyqLKleD-2lziivvzE8O8CdIC2vUf0tBD7VuMyldnZruSEZWCuKJPdgKgy9yPpShmD2NyhbwQIAbievGMJIp_JMwz8MkdY5pzhPECGNgCEtUAmsrrctP5V8HuxaxGt9bb-DdPXkYWXW3MPMSlVpGZ5GiIeTABxqYNG2MSoYeQ9x8O3y488jbassTqxExI_4w9MBQBJR9HIXjWrrrenCcDlMY71rzkbdj3mmcn9xMq2vB5OhfHyHTihbUPLSm83aFWSuW9lE7ogMc93XnrB8evIAk6VfsYlS9Q"},{"kty":"EC","crv":"P-256","kid":"711d48d1","x":"tfXCoBU-wXemeQCkME1gMZWK0-UECCHIkedASZR0t-Q","y":"9xzYtnKQdiQJHCtGwpZWF21eP1fy5x4wC822rCilmBw"},{"kty":"EC","crv":"P-384","kid":"d52c9829","x":"tFx6ev6eLs9sNfdyndn4OgbhV6gPFVn7Ul0VD5vwuplJLbIYeFLI6T42tTaE5_Q4","y":"A0gzB8TqxPX7xMzyHH_FXkYG2iROANH_kQxBovSeus6l_QSyqYlipWpBy9BhY9dz"},{"kty":"RSA","e":"AQAB","kid":"ecac72e5","n":"nLbnTvZAUxdmuAbDDUNAfha6mw0fri3UpV2w1PxilflBuSnXJhzo532-YQITogoanMjy_sQ8kHUhZYHVRR6vLZRBBbl-hP8XWiCe4wwioy7Ey3TiIUYfW-SD6I42XbLt5o-47IR0j5YDXxnX2UU7-UgR_kITBeLDfk0rSp4B0GUhPbP5IDItS0MHHDDS3lhvJomxgEfoNrp0K0Fz_s0K33hfOqc2hD1tSkX-3oDTQVRMF4Nxax3NNw8-ahw6HNMlXlwWfXodgRMvj9pcz8xUYa3C5IlPlZkMumeNCFx1qds6K_eYcU0ss91DdbhhE8amRX1FsnBJNMRUkA5i45xkOIx15rQN230zzh0p71jvtx7wYRr5pdMlwxV0T9Ck5PCmx-GzFazA2X6DJ0Xnn1-cXkRoZHFj_8Mba1dUrNz-NWEk83uW5KT-ZEbX7nzGXtayKWmGb873a8aYPqIsp6bQ_-eRBd8TDT2g9HuPyPr5VKa1p33xKaohz4DGy3t1Qpy3UWnbPXUlh5dLWPKz-TcS9FP5gFhWVo-ZhU03Pn6P34OxHmXGWyQao18dQGqzgD4e9vY3rLhfcjVZJYNlWY2InsNwbYS-DnienPf1ws-miLeXxNKG3tFydoQzHwyOxG6Wc-HBfzL_hOvxINKQamvPasaYWl1LWznMps6elKCgKDc"},{"kty":"EC","crv":"P-521","kid":"c570888f","x":"AHNpXq0J7rikNRlwhaMYDD8LGVAVJzNJ-jEPksUIn2LB2LCdNRzfAhgbxdQcWT9ktlc9M1EhmTLccEqfnWdGL9G1","y":"AfHPUW3GYzzqbTczcYR0nYMVMFVrYsUxv4uiuSNV_XRN3Jf8zeYbbOLJv4S3bUytO7qHY8bfZxPxR9nn3BBTf5ol"}]}`) + + // Create the keyfunc.Keyfunc. + k, err := NewJWKSetJSON(jwksJSON) + if err != nil { + t.Fatalf("Failed to create a keyfunc.Keyfunc.\nError: %s", err) + } + + // Get a JWT to parse. + jwtB64 := "eyJraWQiOiJlZThkNjI2ZCIsInR5cCI6IkpXVCIsImFsZyI6IlJTMjU2In0.eyJzdWIiOiJXZWlkb25nIiwiYXVkIjoiVGFzaHVhbiIsImlzcyI6Imp3a3Mtc2VydmljZS5hcHBzcG90LmNvbSIsImlhdCI6MTYzMTM2OTk1NSwianRpIjoiNDY2M2E5MTAtZWU2MC00NzcwLTgxNjktY2I3NDdiMDljZjU0In0.LwD65d5h6U_2Xco81EClMa_1WIW4xXZl8o4b7WzY_7OgPD2tNlByxvGDzP7bKYA9Gj--1mi4Q4li4CAnKJkaHRYB17baC0H5P9lKMPuA6AnChTzLafY6yf-YadA7DmakCtIl7FNcFQQL2DXmh6gS9J6TluFoCIXj83MqETbDWpL28o3XAD_05UP8VLQzH2XzyqWKi97mOuvz-GsDp9mhBYQUgN3csNXt2v2l-bUPWe19SftNej0cxddyGu06tXUtaS6K0oe0TTbaqc3hmfEiu5G0J8U6ztTUMwXkBvaknE640NPgMQJqBaey0E4u0txYgyvMvvxfwtcOrDRYqYPBnA" + + // Parse the JWT. + token, err := jwt.Parse(jwtB64, k.Keyfunc) + if err != nil { + t.Fatalf("Failed to parse the JWT.\nError: %s", err) + } + + // Check if the token is valid. + if !token.Valid { + t.Fatalf("The token is not valid.") + } +} + +func TestVerificationKeySet(t *testing.T) { + ctx := context.Background() + _, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("Failed to generate ED25519 key pair: %v", err) + } + jwk, err := jwkset.NewJWKFromKey(priv, jwkset.JWKOptions{}) + if err != nil { + t.Fatalf("Failed to create JWK: %v", err) + } + store := jwkset.NewMemoryStorage() + err = store.KeyWrite(ctx, jwk) + if err != nil { + t.Fatalf("Failed to write JWK: %v", err) + } + k, err := New(Options{Ctx: ctx, Storage: store}) + if err != nil { + t.Fatalf("Failed to create Keyfunc: %v", err) + } + vks, err := k.VerificationKeySet(ctx) + if err != nil { + t.Fatalf("VerificationKeySet failed: %v", err) + } + if len(vks.Keys) != 1 { + t.Fatalf("Expected 1 key, got %d", len(vks.Keys)) + } +} + +func TestNoKIDHeaderCallsVerificationKeySet(t *testing.T) { + ctx := context.Background() + + // Generate two key pairs. + _, priv1, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("Failed to generate ED25519 key pair 1: %v", err) + } + _, priv2, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("Failed to generate ED25519 key pair 2: %v", err) + } + + jwk1, err := jwkset.NewJWKFromKey(priv1, jwkset.JWKOptions{}) + if err != nil { + t.Fatalf("Failed to create JWK 1: %v", err) + } + jwk2, err := jwkset.NewJWKFromKey(priv2, jwkset.JWKOptions{}) + if err != nil { + t.Fatalf("Failed to create JWK 2: %v", err) + } + + orders := [][]jwkset.JWK{ + {jwk1, jwk2}, + {jwk2, jwk1}, + } + privs := []ed25519.PrivateKey{priv1, priv2} + + for i, order := range orders { + store := jwkset.NewMemoryStorage() + for _, jwk := range order { + err = store.KeyWrite(ctx, jwk) + if err != nil { + t.Fatalf("Failed to write JWK: %v", err) + } + } + k, err := New(Options{Ctx: ctx, Storage: store}) + if err != nil { + t.Fatalf("Failed to create Keyfunc: %v", err) + } + // Sign a token with the corresponding private key (no KID header) + token := jwt.New(jwt.SigningMethodEdDSA) + tokenString, err := token.SignedString(privs[i]) + if err != nil { + t.Fatalf("Failed to sign token: %v", err) + } + parsedToken, err := jwt.Parse(tokenString, k.KeyfuncCtx(ctx)) + if err != nil { + t.Fatalf("Parse failed (order %d): %v", i+1, err) + } + if !parsedToken.Valid { + t.Fatalf("Expected token to be valid (order %d)", i+1) + } + } +} + +func TestNoKIDHeaderNoMatchingJWK(t *testing.T) { + ctx := context.Background() + + _, missingFromSet, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("Failed to generate ED25519 key pair: %v", err) + } + + _, presentInSet, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatalf("Failed to generate other ED25519 key pair: %v", err) + } + jwk, err := jwkset.NewJWKFromKey(presentInSet, jwkset.JWKOptions{}) + if err != nil { + t.Fatalf("Failed to create JWK: %v", err) + } + store := jwkset.NewMemoryStorage() + err = store.KeyWrite(ctx, jwk) + if err != nil { + t.Fatalf("Failed to write JWK: %v", err) + } + + k, err := New(Options{Ctx: ctx, Storage: store}) + if err != nil { + t.Fatalf("Failed to create Keyfunc: %v", err) + } + + token := jwt.New(jwt.SigningMethodEdDSA) + tokenString, err := token.SignedString(missingFromSet) + if err != nil { + t.Fatalf("Failed to sign token: %v", err) + } + + _, err = jwt.Parse(tokenString, k.KeyfuncCtx(ctx)) + if err == nil { + t.Fatalf("Expected error due to no matching JWK, but got none") + } +} diff --git a/router/pkg/authentication/validation_store.go b/router/pkg/authentication/validation_store.go deleted file mode 100644 index 96f174b246..0000000000 --- a/router/pkg/authentication/validation_store.go +++ /dev/null @@ -1,173 +0,0 @@ -package authentication - -import ( - "context" - "encoding/json" - "fmt" - - "github.com/MicahParks/jwkset" - "go.uber.org/zap" -) - -var _ jwkset.Storage = (*validationStore)(nil) - -type validationStore struct { - logger *zap.Logger - algs map[string]struct{} - inner jwkset.Storage - allowEmptyAlgorithm bool -} - -var supportedAlgorithms = map[string]struct{}{ - "HS256": {}, - "HS384": {}, - "HS512": {}, - "RS256": {}, - "RS384": {}, - "RS512": {}, - "PS256": {}, - "PS384": {}, - "PS512": {}, - "ES256": {}, - "ES384": {}, - "ES512": {}, - "EdDSA": {}, -} - -func NewValidationStore(logger *zap.Logger, inner jwkset.Storage, algs []string, allowEmptyAlgorithm bool) (jwkset.Storage, []string) { - if inner == nil { - inner = jwkset.NewMemoryStorage() - } - - if logger == nil { - logger = zap.NewNop() - } - - algSet := make(map[string]struct{}, len(algs)) - - store := &validationStore{ - logger: logger, - inner: inner, - algs: supportedAlgorithms, - allowEmptyAlgorithm: allowEmptyAlgorithm, - } - - if len(algs) == 0 { - return store, store.getSupportedAlgorithms() - } - - for _, alg := range algs { - if _, ok := supportedAlgorithms[alg]; !ok { - logger.Warn("Unsupported algorithm", zap.String("algorithm", alg)) - continue - } - algSet[alg] = struct{}{} - } - - store.algs = algSet - return store, store.getSupportedAlgorithms() -} - -func (v *validationStore) getSupportedAlgorithms() []string { - algs := make([]string, 0, len(v.algs)) - for alg := range v.algs { - algs = append(algs, alg) - } - return algs -} - -func (v *validationStore) KeyDelete(ctx context.Context, keyID string) (ok bool, err error) { - return v.inner.KeyDelete(ctx, keyID) -} - -func (v *validationStore) KeyRead(ctx context.Context, keyID string) (jwkset.JWK, error) { - key, err := v.inner.KeyRead(ctx, keyID) - if err != nil { - return key, err - } - - if fKey, ok := v.getFilteredKey(key); ok { - return fKey, nil - } - - return jwkset.JWK{}, fmt.Errorf("key with ID %q has an unsupported algorithm %s", keyID, key.Marshal().ALG.String()) -} - -func (v *validationStore) KeyReadAll(ctx context.Context) ([]jwkset.JWK, error) { - keys, err := v.inner.KeyReadAll(ctx) - if err != nil { - return nil, err - } - - filter := make([]jwkset.JWK, 0, len(keys)) - - for _, k := range keys { - if fKey, ok := v.getFilteredKey(k); ok { - filter = append(filter, fKey) - } - } - - return filter, nil -} - -func (v *validationStore) KeyReplaceAll(ctx context.Context, given []jwkset.JWK) error { - filtered := make([]jwkset.JWK, 0) - for _, k := range given { - if fKey, ok := v.getFilteredKey(k); ok { - filtered = append(filtered, fKey) - } - } - return v.inner.KeyReplaceAll(ctx, filtered) -} - -func (v *validationStore) KeyWrite(ctx context.Context, jwk jwkset.JWK) error { - if _, ok := v.getFilteredKey(jwk); !ok { - // We should not return an error here. If JWKS are configured for multiple applications, we should only add the - // supported keys to the token decoder store and not prevent the refresh entirely. - // In case we are receiving a key with an unsupported algorithm we log a warning instead. - jwkMarshal := jwk.Marshal() - v.logger.Warn("Skipping key with unsupported algorithm", zap.String("keyID", jwkMarshal.KID), zap.String("algorithm", jwkMarshal.ALG.String())) - return nil - } - - return v.inner.KeyWrite(ctx, jwk) -} - -func (v *validationStore) JSON(ctx context.Context) (json.RawMessage, error) { - return v.inner.JSON(ctx) -} - -func (v *validationStore) JSONPublic(ctx context.Context) (json.RawMessage, error) { - return v.inner.JSONPublic(ctx) -} - -func (v *validationStore) JSONPrivate(ctx context.Context) (json.RawMessage, error) { - return v.inner.JSONPrivate(ctx) -} - -func (v *validationStore) JSONWithOptions(ctx context.Context, marshalOptions jwkset.JWKMarshalOptions, validationOptions jwkset.JWKValidateOptions) (json.RawMessage, error) { - return v.inner.JSONWithOptions(ctx, marshalOptions, validationOptions) -} - -func (v *validationStore) Marshal(ctx context.Context) (jwkset.JWKSMarshal, error) { - return v.inner.Marshal(ctx) -} - -func (v *validationStore) MarshalWithOptions(ctx context.Context, marshalOptions jwkset.JWKMarshalOptions, validationOptions jwkset.JWKValidateOptions) (jwkset.JWKSMarshal, error) { - return v.inner.MarshalWithOptions(ctx, marshalOptions, validationOptions) -} - -func (v *validationStore) getFilteredKey(k jwkset.JWK) (jwkset.JWK, bool) { - algString := k.Marshal().ALG.String() - - // If we allow empty algorithm, we accept JWK without an algorithm - // This is algorithm is actually optional according to the RFC - if algString == "" && v.allowEmptyAlgorithm { - return k, true - } - if _, ok := v.algs[algString]; ok { - return k, true - } - - return jwkset.JWK{}, false -} diff --git a/router/pkg/authentication/validation_store_test.go b/router/pkg/authentication/validation_store_test.go deleted file mode 100644 index 3c4825e079..0000000000 --- a/router/pkg/authentication/validation_store_test.go +++ /dev/null @@ -1,385 +0,0 @@ -package authentication - -import ( - "context" - "crypto/ed25519" - "crypto/hmac" - "crypto/rand" - "crypto/rsa" - "crypto/sha256" - "testing" - - "github.com/MicahParks/jwkset" - requires "github.com/stretchr/testify/require" - "go.uber.org/zap" -) - -func TestValidationStore(t *testing.T) { - t.Parallel() - - t.Run("verify KeyWrite", func(t *testing.T) { - t.Parallel() - - t.Run("accepts supported algorithms without filter", func(t *testing.T) { - t.Parallel() - - inner := jwkset.NewMemoryStorage() - store, _ := NewValidationStore(nil, inner, nil, false) - keys := []jwkset.JWK{ - genRSAJWK(t, "rsa1", jwkset.AlgRS256), - genHMACJWK(t, "hmac1", jwkset.AlgHS256), - genEd25519JWK(t, "eddsa1"), - } - ctx := context.Background() - for _, k := range keys { - requires.NoError(t, store.KeyWrite(ctx, k)) - } - allInner, err := inner.KeyReadAll(ctx) - requires.NoError(t, err) - requires.Len(t, allInner, len(keys)) - }) - - t.Run("skips disallowed algorithms when filtered", func(t *testing.T) { - t.Parallel() - - inner := jwkset.NewMemoryStorage() - store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, false) - allowed := genRSAJWK(t, "rsa-allowed", jwkset.AlgRS256) - disallowed := genHMACJWK(t, "hmac-blocked", jwkset.AlgHS256) - ctx := context.Background() - requires.NoError(t, store.KeyWrite(ctx, allowed)) - requires.NoError(t, store.KeyWrite(ctx, disallowed)) // skipped, not error - all, err := inner.KeyReadAll(ctx) - requires.NoError(t, err) - requires.Len(t, all, 1) - requires.Equal(t, allowed.Marshal().KID, all[0].Marshal().KID) - }) - - t.Run("accepts empty algorithm when allowed", func(t *testing.T) { - t.Parallel() - - inner := jwkset.NewMemoryStorage() - store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, true) - ctx := context.Background() - noAlg := genRSAJWK(t, "noalg-write-allowed", jwkset.ALG("")) - requires.NoError(t, store.KeyWrite(ctx, noAlg)) - all, err := inner.KeyReadAll(ctx) - requires.NoError(t, err) - requires.Len(t, all, 1) - requires.Equal(t, "noalg-write-allowed", all[0].Marshal().KID) - }) - - t.Run("skips empty algorithm when not allowed", func(t *testing.T) { - t.Parallel() - - inner := jwkset.NewMemoryStorage() - store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, false) - ctx := context.Background() - noAlg := genRSAJWK(t, "noalg-write-deny", jwkset.ALG("")) - requires.NoError(t, store.KeyWrite(ctx, noAlg)) - all, err := inner.KeyReadAll(ctx) - requires.NoError(t, err) - requires.Empty(t, all) - }) - }) - - t.Run("verify KeyRead", func(t *testing.T) { - t.Parallel() - - t.Run("allowed and disallowed algorithms", func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - inner := jwkset.NewMemoryStorage() - allowed := genRSAJWK(t, "rsa-1", jwkset.AlgRS256) - disallowed := genHMACJWK(t, "hmac-1", jwkset.AlgHS256) - requires.NoError(t, inner.KeyWrite(ctx, allowed)) - requires.NoError(t, inner.KeyWrite(ctx, disallowed)) - store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, false) - - _, err := store.KeyRead(ctx, allowed.Marshal().KID) - requires.NoError(t, err) - _, err = store.KeyRead(ctx, disallowed.Marshal().KID) - requires.ErrorContains(t, err, "unsupported algorithm") - }) - - t.Run("empty algorithm allowed returns key", func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - inner := jwkset.NewMemoryStorage() - noAlg := genRSAJWK(t, "noalg-read-allowed", jwkset.ALG("")) - requires.NoError(t, inner.KeyWrite(ctx, noAlg)) - store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, true) - - _, err := store.KeyRead(ctx, "noalg-read-allowed") - requires.NoError(t, err) - }) - - t.Run("empty algorithm not allowed returns error", func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - inner := jwkset.NewMemoryStorage() - noAlg := genRSAJWK(t, "noalg-read-deny", jwkset.ALG("")) - requires.NoError(t, inner.KeyWrite(ctx, noAlg)) - store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, false) - - _, err := store.KeyRead(ctx, "noalg-read-deny") - requires.ErrorContains(t, err, "unsupported algorithm") - }) - }) - - t.Run("verify KeyReadAll", func(t *testing.T) { - t.Parallel() - - t.Run("filters to allowed algorithms", func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - inner := jwkset.NewMemoryStorage() - allowed := genRSAJWK(t, "rsa-1", jwkset.AlgRS256) - disallowed := genHMACJWK(t, "hmac-1", jwkset.AlgHS256) - requires.NoError(t, inner.KeyWrite(ctx, allowed)) - requires.NoError(t, inner.KeyWrite(ctx, disallowed)) - store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, false) - - keys, err := store.KeyReadAll(ctx) - requires.NoError(t, err) - requires.Len(t, keys, 1) - m := keys[0].Marshal() - requires.Equal(t, allowed.Marshal().KID, m.KID) - }) - - t.Run("includes empty algorithm when allowed", func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - inner2 := jwkset.NewMemoryStorage() - allowed2 := genRSAJWK(t, "rsa-2", jwkset.AlgRS256) - noAlg := genRSAJWK(t, "noalg-readall-allowed", jwkset.ALG("")) - requires.NoError(t, inner2.KeyWrite(ctx, allowed2)) - requires.NoError(t, inner2.KeyWrite(ctx, noAlg)) - store2, _ := NewValidationStore(zap.NewNop(), inner2, []string{"RS256"}, true) - keys2, err := store2.KeyReadAll(ctx) - requires.NoError(t, err) - requires.Len(t, keys2, 2) - }) - }) - - t.Run("verify KeyReplaceAll", func(t *testing.T) { - t.Parallel() - - t.Run("replaces with only allowed algorithms", func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - inner := jwkset.NewMemoryStorage() - store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, false) - allowed := genRSAJWK(t, "rsa-1", jwkset.AlgRS256) - disallowed := genHMACJWK(t, "hmac-1", jwkset.AlgHS256) - requires.NoError(t, store.KeyReplaceAll(ctx, []jwkset.JWK{allowed, disallowed})) - all, err := inner.KeyReadAll(ctx) - requires.NoError(t, err) - requires.Len(t, all, 1) - requires.Equal(t, allowed.Marshal().KID, all[0].Marshal().KID) - }) - - t.Run("includes empty algorithm on replace when allowed", func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - inner := jwkset.NewMemoryStorage() - store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, true) - allowed := genRSAJWK(t, "rsa-3", jwkset.AlgRS256) - noAlg := genRSAJWK(t, "noalg-replace-allowed", jwkset.ALG("")) - requires.NoError(t, store.KeyReplaceAll(ctx, []jwkset.JWK{allowed, noAlg})) - all, err := inner.KeyReadAll(ctx) - requires.NoError(t, err) - requires.Len(t, all, 2) - }) - }) - - t.Run("verify KeyDelete", func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - inner := jwkset.NewMemoryStorage() - store, _ := NewValidationStore(zap.NewNop(), inner, nil, false) - - key := genRSAJWK(t, "rsa-del", jwkset.AlgRS256) - requires.NoError(t, store.KeyWrite(ctx, key)) - ok, err := store.KeyDelete(ctx, key.Marshal().KID) - requires.NoError(t, err) - requires.True(t, ok) - _, err = inner.KeyRead(ctx, key.Marshal().KID) - requires.ErrorContains(t, err, "not found") - }) - - t.Run("verify JSON", func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - inner := jwkset.NewMemoryStorage() - allowed := genRSAJWK(t, "rsa-json", jwkset.AlgRS256) - disallowed := genHMACJWK(t, "hmac-json", jwkset.AlgHS256) - requires.NoError(t, inner.KeyWrite(ctx, allowed)) - requires.NoError(t, inner.KeyWrite(ctx, disallowed)) - store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, false) - - a, err := store.JSON(ctx) - requires.NoError(t, err) - b, err := inner.JSON(ctx) - requires.NoError(t, err) - requires.JSONEq(t, string(b), string(a)) - }) - - t.Run("verify JSONPublic", func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - inner := jwkset.NewMemoryStorage() - allowed := genRSAJWK(t, "rsa-json", jwkset.AlgRS256) - disallowed := genHMACJWK(t, "hmac-json", jwkset.AlgHS256) - requires.NoError(t, inner.KeyWrite(ctx, allowed)) - requires.NoError(t, inner.KeyWrite(ctx, disallowed)) - store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, false) - - a, err := store.JSONPublic(ctx) - requires.NoError(t, err) - b, err := inner.JSONPublic(ctx) - requires.NoError(t, err) - requires.JSONEq(t, string(b), string(a)) - }) - - t.Run("verify JSONPrivate", func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - inner := jwkset.NewMemoryStorage() - allowed := genRSAJWK(t, "rsa-json", jwkset.AlgRS256) - disallowed := genHMACJWK(t, "hmac-json", jwkset.AlgHS256) - requires.NoError(t, inner.KeyWrite(ctx, allowed)) - requires.NoError(t, inner.KeyWrite(ctx, disallowed)) - store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, false) - - a, err := store.JSONPrivate(ctx) - requires.NoError(t, err) - b, err := inner.JSONPrivate(ctx) - requires.NoError(t, err) - requires.JSONEq(t, string(b), string(a)) - }) - - t.Run("verify JSONWithOptions", func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - inner := jwkset.NewMemoryStorage() - allowed := genRSAJWK(t, "rsa-json", jwkset.AlgRS256) - disallowed := genHMACJWK(t, "hmac-json", jwkset.AlgHS256) - requires.NoError(t, inner.KeyWrite(ctx, allowed)) - requires.NoError(t, inner.KeyWrite(ctx, disallowed)) - store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, false) - - a, err := store.JSONWithOptions(ctx, jwkset.JWKMarshalOptions{}, jwkset.JWKValidateOptions{}) - requires.NoError(t, err) - b, err := inner.JSONWithOptions(ctx, jwkset.JWKMarshalOptions{}, jwkset.JWKValidateOptions{}) - requires.NoError(t, err) - requires.JSONEq(t, string(b), string(a)) - }) - - t.Run("verify Marshal", func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - inner := jwkset.NewMemoryStorage() - allowed := genRSAJWK(t, "rsa-json", jwkset.AlgRS256) - disallowed := genHMACJWK(t, "hmac-json", jwkset.AlgHS256) - requires.NoError(t, inner.KeyWrite(ctx, allowed)) - requires.NoError(t, inner.KeyWrite(ctx, disallowed)) - store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, false) - - ma, err := store.Marshal(ctx) - requires.NoError(t, err) - mb, err := inner.Marshal(ctx) - requires.NoError(t, err) - requires.Len(t, ma.Keys, len(mb.Keys)) - }) - - t.Run("verify MarshalWithOptions", func(t *testing.T) { - t.Parallel() - - ctx := context.Background() - inner := jwkset.NewMemoryStorage() - allowed := genRSAJWK(t, "rsa-json", jwkset.AlgRS256) - disallowed := genHMACJWK(t, "hmac-json", jwkset.AlgHS256) - requires.NoError(t, inner.KeyWrite(ctx, allowed)) - requires.NoError(t, inner.KeyWrite(ctx, disallowed)) - store, _ := NewValidationStore(zap.NewNop(), inner, []string{"RS256"}, false) - - ma, err := store.MarshalWithOptions(ctx, jwkset.JWKMarshalOptions{}, jwkset.JWKValidateOptions{}) - requires.NoError(t, err) - mb, err := inner.MarshalWithOptions(ctx, jwkset.JWKMarshalOptions{}, jwkset.JWKValidateOptions{}) - requires.NoError(t, err) - requires.Len(t, ma.Keys, len(mb.Keys)) - }) - - t.Run("verify ConstructorSupportedAlgorithms", func(t *testing.T) { - t.Parallel() - - inner := jwkset.NewMemoryStorage() - _, algs := NewValidationStore(zap.NewNop(), inner, nil, false) - requires.Len(t, algs, len(supportedAlgorithms)) - _, algs2 := NewValidationStore(zap.NewNop(), inner, []string{"RS256", "INVALID"}, false) - requires.ElementsMatch(t, []string{"RS256"}, algs2) - }) -} - -func genRSAJWK(t *testing.T, kid string, alg jwkset.ALG) jwkset.JWK { - t.Helper() - pk, err := rsa.GenerateKey(rand.Reader, 2048) - requires.NoError(t, err) - meta := jwkset.JWKMetadataOptions{KID: kid, USE: jwkset.UseSig} - if alg.String() != "" { - meta.ALG = alg - } - opts := jwkset.JWKOptions{ - Marshal: jwkset.JWKMarshalOptions{Private: false}, - Metadata: meta, - } - j, err := jwkset.NewJWKFromKey(pk, opts) - requires.NoError(t, err) - return j -} - -func genHMACJWK(t *testing.T, kid string, alg jwkset.ALG) jwkset.JWK { - t.Helper() - secret := make([]byte, 64) - _, err := rand.Read(secret) - requires.NoError(t, err) - // Use HMAC to derive a stable-length key material; any []byte works for JWK creation. - h := hmac.New(sha256.New, secret) - _, err = h.Write([]byte("test")) - requires.NoError(t, err) - key := h.Sum(nil) - opts := jwkset.JWKOptions{ - Marshal: jwkset.JWKMarshalOptions{Private: true}, - Metadata: jwkset.JWKMetadataOptions{ALG: alg, KID: kid, USE: jwkset.UseSig}, - } - j, err := jwkset.NewJWKFromKey(key, opts) - requires.NoError(t, err) - return j -} - -func genEd25519JWK(t *testing.T, kid string) jwkset.JWK { - t.Helper() - _, priv, err := ed25519.GenerateKey(rand.Reader) - requires.NoError(t, err) - opts := jwkset.JWKOptions{ - Marshal: jwkset.JWKMarshalOptions{Private: false}, - Metadata: jwkset.JWKMetadataOptions{ALG: jwkset.AlgEdDSA, KID: kid, USE: jwkset.UseSig}, - } - j, err := jwkset.NewJWKFromKey(priv, opts) - requires.NoError(t, err) - return j -} diff --git a/router/pkg/config/config.go b/router/pkg/config/config.go index fe3fd68c52..57591de31e 100644 --- a/router/pkg/config/config.go +++ b/router/pkg/config/config.go @@ -478,8 +478,7 @@ type JWKSConfiguration struct { KeyId string `yaml:"header_key_id"` // Common - Audiences []string `yaml:"audiences"` - AllowEmptyAlgorithm bool `yaml:"allow_empty_algorithm" envDefault:"false"` + Audiences []string `yaml:"audiences"` } type RefreshUnknownKID struct { diff --git a/router/pkg/config/testdata/config_full.json b/router/pkg/config/testdata/config_full.json index ef95998140..ba295c3935 100644 --- a/router/pkg/config/testdata/config_full.json +++ b/router/pkg/config/testdata/config_full.json @@ -489,8 +489,7 @@ "Secret": "", "Algorithm": "", "KeyId": "", - "Audiences": null, - "AllowEmptyAlgorithm": true + "Audiences": null }, { "URL": "https://example.com/.well-known/jwks2.json", @@ -508,8 +507,7 @@ "Secret": "", "Algorithm": "", "KeyId": "", - "Audiences": null, - "AllowEmptyAlgorithm": false + "Audiences": null }, { "URL": "https://example.com/.well-known/jwks3.json", @@ -524,8 +522,7 @@ "Secret": "", "Algorithm": "", "KeyId": "", - "Audiences": null, - "AllowEmptyAlgorithm": false + "Audiences": null } ], "HeaderName": "Authorization", From 1d571d0f59baf2807b9c07e74610ebaa0b472c7e Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Mon, 22 Sep 2025 17:00:43 +0530 Subject: [PATCH 28/45] fix: cleanup --- router-tests/utils.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/router-tests/utils.go b/router-tests/utils.go index 825bcdb19e..09bf9b1fcc 100644 --- a/router-tests/utils.go +++ b/router-tests/utils.go @@ -42,10 +42,6 @@ func RequireSpanWithName(t *testing.T, exporter *tracetest2.InMemoryExporter, na return testSpan } -type ConfigureAuthOpts struct { - AllowEmptyAlgorithm bool -} - func ConfigureAuth(t *testing.T) ([]authentication.Authenticator, *jwks.Server) { authServer, err := jwks.NewServer(t) require.NoError(t, err) From b45a26379222974f769bd5c7d494dce29525b24c Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Mon, 22 Sep 2025 17:06:37 +0530 Subject: [PATCH 29/45] fix: cleanup --- router-tests/jwks/crypto.go | 5 +---- .../pkg/authentication/jwks_token_decoder.go | 22 +++++-------------- 2 files changed, 6 insertions(+), 21 deletions(-) diff --git a/router-tests/jwks/crypto.go b/router-tests/jwks/crypto.go index 403bd5af39..0b157a8440 100644 --- a/router-tests/jwks/crypto.go +++ b/router-tests/jwks/crypto.go @@ -43,14 +43,11 @@ func (b *baseCrypto) MarshalJWK() (jwkset.JWK, error) { } meta := jwkset.JWKMetadataOptions{ + ALG: b.alg, KID: b.kID, USE: jwkset.UseSig, } - if b.alg != "" { - meta.ALG = b.alg - } - options := jwkset.JWKOptions{ Marshal: marshalOptions, Metadata: meta, diff --git a/router/pkg/authentication/jwks_token_decoder.go b/router/pkg/authentication/jwks_token_decoder.go index 27f70cbc78..dd401ee1dc 100644 --- a/router/pkg/authentication/jwks_token_decoder.go +++ b/router/pkg/authentication/jwks_token_decoder.go @@ -68,14 +68,9 @@ type audKey struct { type audienceSet map[string]struct{} -type keyFuncWithOpts struct { - keyFunc keyfunc.Keyfunc - allowedAlgorithms []string -} - func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKSConfig) (TokenDecoder, error) { audiencesMap := make(map[audKey]audienceSet, len(configs)) - keyFuncMap := make(map[audKey]keyFuncWithOpts, len(configs)) + keyFuncMap := make(map[audKey]keyfunc.Keyfunc, len(configs)) for _, c := range configs { if c.URL != "" { @@ -123,10 +118,7 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS if err != nil { return nil, err } - keyFuncMap[key] = keyFuncWithOpts{ - keyFunc: jwks, - allowedAlgorithms: c.AllowedAlgorithms, - } + keyFuncMap[key] = jwks } else if c.Secret != "" { key := audKey{kid: c.KeyId} @@ -176,18 +168,14 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS if err != nil { return nil, err } - keyFuncMap[key] = keyFuncWithOpts{ - keyFunc: jwks, - allowedAlgorithms: []string{c.Algorithm}, - } + keyFuncMap[key] = jwks } } keyFuncWrapper := jwt.Keyfunc(func(token *jwt.Token) (any, error) { var errJoin error - for key, keyFuncAndOpts := range keyFuncMap { - - pub, err := keyFuncAndOpts.keyFunc.Keyfunc(token) + for key, keyFunc := range keyFuncMap { + pub, err := keyFunc.Keyfunc(token) if err != nil { errJoin = errors.Join(errJoin, err) continue From eda41f547dc4911716848b5f10565025a916f3dc Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Mon, 22 Sep 2025 17:16:22 +0530 Subject: [PATCH 30/45] fix: cleanup --- router/pkg/config/config.schema.json | 51 +++++++--------------------- router/pkg/config/fixtures/full.yaml | 1 - 2 files changed, 12 insertions(+), 40 deletions(-) diff --git a/router/pkg/config/config.schema.json b/router/pkg/config/config.schema.json index 87e924d0ee..570d5f8086 100644 --- a/router/pkg/config/config.schema.json +++ b/router/pkg/config/config.schema.json @@ -1687,10 +1687,6 @@ "type": "string", "description": "The KID header of the JWK token created using the secret" }, - "allow_empty_algorithm": { - "type": "boolean", - "description": "This attribute can be enabled to allow for the JWK to contain keys with empty algorithms" - }, "algorithms": { "type": "array", "description": "The allowed algorithms for the keys that are retrieved from the JWKs. An empty list means that all algorithms are allowed.", @@ -1758,40 +1754,20 @@ }, "oneOf": [ { - "allOf": [ - { - "required": ["url"], - "not": { - "anyOf": [ - { - "required": ["secret"] - }, - { - "required": ["symmetric_algorithm"] - }, - { - "required": ["header_key_id"] - } - ] - } - }, - { - "if": { - "required": ["allow_empty_algorithm"], - "properties": { - "allow_empty_algorithm": { "const": true } - } + "required": ["url"], + "not": { + "anyOf": [ + { + "required": ["secret"] }, - "then": { - "required": ["algorithms"], - "properties": { - "algorithms": { - "minItems": 1 - } - } + { + "required": ["symmetric_algorithm"] + }, + { + "required": ["header_key_id"] } - } - ] + ] + } }, { "required": ["secret", "symmetric_algorithm", "header_key_id"], @@ -1805,9 +1781,6 @@ }, { "required": ["refresh_interval"] - }, - { - "required": ["allow_empty_algorithm"] } ] } diff --git a/router/pkg/config/fixtures/full.yaml b/router/pkg/config/fixtures/full.yaml index 7121f81639..a43691cc12 100644 --- a/router/pkg/config/fixtures/full.yaml +++ b/router/pkg/config/fixtures/full.yaml @@ -272,7 +272,6 @@ authentication: - url: 'https://example.com/.well-known/jwks.json' refresh_interval: 1m algorithms: ['RS256'] - allow_empty_algorithm: true - url: 'https://example.com/.well-known/jwks2.json' refresh_interval: 2m algorithms: ['RS256', 'ES256'] From ab80adbe4056ccc665cc4265b2d7e1100741c958 Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Mon, 22 Sep 2025 19:23:16 +0530 Subject: [PATCH 31/45] fix: external dependency --- router-tests/go.mod | 1 + router-tests/go.sum | 2 + router-tests/jwks/jwks.go | 9 +- router/go.mod | 2 + router/go.sum | 2 + .../pkg/authentication/jwks_token_decoder.go | 3 +- router/pkg/authentication/keyfunc/keyfunc.go | 300 ----------------- .../authentication/keyfunc/keyfunc_test.go | 302 ------------------ 8 files changed, 11 insertions(+), 610 deletions(-) delete mode 100644 router/pkg/authentication/keyfunc/keyfunc.go delete mode 100644 router/pkg/authentication/keyfunc/keyfunc_test.go diff --git a/router-tests/go.mod b/router-tests/go.mod index 9d2d9bb3b9..7b3061f68c 100644 --- a/router-tests/go.mod +++ b/router-tests/go.mod @@ -45,6 +45,7 @@ require ( connectrpc.com/connect v1.16.2 // indirect github.com/99designs/gqlgen v0.17.76 // indirect github.com/KimMachineGun/automemlimit v0.6.1 // indirect + github.com/MicahParks/keyfunc/v3 v3.3.5 // indirect github.com/agnivade/levenshtein v1.2.1 // indirect github.com/andybalholm/brotli v1.1.0 // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect diff --git a/router-tests/go.sum b/router-tests/go.sum index 39cb22207b..8ba0cbab58 100644 --- a/router-tests/go.sum +++ b/router-tests/go.sum @@ -7,6 +7,8 @@ github.com/KimMachineGun/automemlimit v0.6.1 h1:ILa9j1onAAMadBsyyUJv5cack8Y1WT26 github.com/KimMachineGun/automemlimit v0.6.1/go.mod h1:T7xYht7B8r6AG/AqFcUdc7fzd2bIdBKmepfP2S1svPY= github.com/MicahParks/jwkset v0.11.0 h1:yc0zG+jCvZpWgFDFmvs8/8jqqVBG9oyIbmBtmjOhoyQ= github.com/MicahParks/jwkset v0.11.0/go.mod h1:U2oRhRaLgDCLjtpGL2GseNKGmZtLs/3O7p+OZaL5vo0= +github.com/MicahParks/keyfunc/v3 v3.3.5 h1:7ceAJLUAldnoueHDNzF8Bx06oVcQ5CfJnYwNt1U3YYo= +github.com/MicahParks/keyfunc/v3 v3.3.5/go.mod h1:SdCCyMJn/bYqWDvARspC6nCT8Sk74MjuAY22C7dCST8= github.com/PuerkitoBio/goquery v1.10.3 h1:pFYcNSqHxBD06Fpj/KsbStFRsgRATgnf3LeXiUkhzPo= github.com/PuerkitoBio/goquery v1.10.3/go.mod h1:tMUX0zDMHXYlAQk6p35XxQMqMweEKB7iK7iLNd4RH4Y= github.com/agnivade/levenshtein v1.2.1 h1:EHBY3UOn1gwdy/VbFwgo4cxecRznFk7fKWN1KOX7eoM= diff --git a/router-tests/jwks/jwks.go b/router-tests/jwks/jwks.go index 5b329ae475..9e0cabcc8e 100644 --- a/router-tests/jwks/jwks.go +++ b/router-tests/jwks/jwks.go @@ -50,17 +50,14 @@ func (s *Server) TokenWithOpts(claims map[string]any, tokenOpts TokenOpts) (stri } for kid, pr := range s.providers { - var token *jwt.Token + method := pr.SigningMethod() if tokenOpts.AlgOverride != "" { - method := jwt.GetSigningMethod(tokenOpts.AlgOverride) + method = jwt.GetSigningMethod(tokenOpts.AlgOverride) if method == nil { return "", fmt.Errorf("unsupported signing method: %s", tokenOpts.AlgOverride) } - token = jwt.NewWithClaims(method, jwt.MapClaims(claims)) - } else { - token = jwt.NewWithClaims(pr.SigningMethod(), jwt.MapClaims(claims)) } - + token := jwt.NewWithClaims(method, jwt.MapClaims(claims)) token.Header[jwkset.HeaderKID] = kid return token.SignedString(pr.PrivateKey()) } diff --git a/router/go.mod b/router/go.mod index 00e4d36b3f..a199025799 100644 --- a/router/go.mod +++ b/router/go.mod @@ -4,6 +4,7 @@ go 1.25 require ( connectrpc.com/connect v1.16.2 + github.com/MicahParks/keyfunc/v3 v3.3.5 github.com/andybalholm/brotli v1.1.0 // indirect github.com/buger/jsonparser v1.1.1 github.com/cespare/xxhash/v2 v2.3.0 @@ -173,6 +174,7 @@ require ( // Do not upgrade, it renames attributes we rely on replace ( + github.com/MicahParks/keyfunc/v3 => github.com/wundergraph/keyfunc/v3 v3.0.0-20250922133930-92f21becf3d9 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp => go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.46.1 go.opentelemetry.io/contrib/propagators/b3 => go.opentelemetry.io/contrib/propagators/b3 v1.23.0 go.opentelemetry.io/contrib/propagators/jaeger => go.opentelemetry.io/contrib/propagators/jaeger v1.23.0 diff --git a/router/go.sum b/router/go.sum index cbf8c86cc3..1d4cb2afbd 100644 --- a/router/go.sum +++ b/router/go.sum @@ -317,6 +317,8 @@ github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 h1:8/D7f8gKxTB github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083/go.mod h1:eOTL6acwctsN4F3b7YE+eE2t8zcJ/doLm9sZzsxxxrE= github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.226 h1:3g6KNCG4ydgnpZnIlCK7pmtv0FSge6ILUS5LjrNZNiI= github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.226/go.mod h1:g1IFIylu5Fd9pKjzq0mDvpaKhEB/vkwLAIbGdX2djXU= +github.com/wundergraph/keyfunc/v3 v3.0.0-20250922133930-92f21becf3d9 h1:7bPpsPUUxy5dEnuDSy2q3PAmflxqKx9vnyaTj3TSMBo= +github.com/wundergraph/keyfunc/v3 v3.0.0-20250922133930-92f21becf3d9/go.mod h1:el0U1ewqJ/T/Urlt3wImfmuBmoQdjL5yoNQ5e/+O98M= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= diff --git a/router/pkg/authentication/jwks_token_decoder.go b/router/pkg/authentication/jwks_token_decoder.go index dd401ee1dc..ce19603bc9 100644 --- a/router/pkg/authentication/jwks_token_decoder.go +++ b/router/pkg/authentication/jwks_token_decoder.go @@ -4,11 +4,10 @@ import ( "context" "errors" "fmt" + "github.com/MicahParks/keyfunc/v3" "net/http" "time" - "github.com/wundergraph/cosmo/router/pkg/authentication/keyfunc" - "golang.org/x/time/rate" "github.com/MicahParks/jwkset" diff --git a/router/pkg/authentication/keyfunc/keyfunc.go b/router/pkg/authentication/keyfunc/keyfunc.go deleted file mode 100644 index 1c37d34b53..0000000000 --- a/router/pkg/authentication/keyfunc/keyfunc.go +++ /dev/null @@ -1,300 +0,0 @@ -// This is forked from https://github.com/MicahParks/keyfunc/blob/main/keyfunc.go -// Copyrights go to the original author. -package keyfunc - -import ( - "context" - "crypto" - "encoding/json" - "errors" - "fmt" - "log/slog" - "slices" - "time" - - "github.com/MicahParks/jwkset" - "github.com/golang-jwt/jwt/v5" - "golang.org/x/time/rate" -) - -var ( - // ErrKeyfunc is returned when a keyfunc error occurs. - ErrKeyfunc = errors.New("failed keyfunc") -) - -// Keyfunc is meant to be used as the jwt.Keyfunc function for github.com/golang-jwt/jwt/v5. It uses -// github.com/MicahParks/jwkset as a JWK Set storage. -type Keyfunc interface { - Keyfunc(token *jwt.Token) (any, error) - KeyfuncCtx(ctx context.Context) jwt.Keyfunc - Storage() jwkset.Storage - VerificationKeySet(ctx context.Context) (jwt.VerificationKeySet, error) -} - -// Options are used to create a new Keyfunc. -type Options struct { - Ctx context.Context - Storage jwkset.Storage - UseWhitelist []jwkset.USE - - // Custom Non Base on original keyfunc - AllowedAlgorithms []string -} - -// Override is used to change specific default behaviors. -type Override struct { - // HTTPTimeout is from https://pkg.go.dev/github.com/MicahParks/jwkset#HTTPClientStorageOptions - HTTPTimeout time.Duration - // RateLimitWaitMax is from https://pkg.go.dev/github.com/MicahParks/jwkset#HTTPClientOptions - RateLimitWaitMax time.Duration - // RefreshErrorHandlerFunc is a function that accepts the URL of the remote JWK Set storage and returns the - // RefreshErrorHandler from https://pkg.go.dev/github.com/MicahParks/jwkset#HTTPClientStorageOptions - RefreshErrorHandlerFunc func(u string) func(ctx context.Context, err error) - // RefreshInterval is from https://pkg.go.dev/github.com/MicahParks/jwkset#HTTPClientStorageOptions - RefreshInterval time.Duration - // RefreshUnknownKID is from https://pkg.go.dev/github.com/MicahParks/jwkset#HTTPClientOptions - RefreshUnknownKID *rate.Limiter - // ValidationSkipAll is copied to SkipAll in https://pkg.go.dev/github.com/MicahParks/jwkset#JWKValidateOptions - ValidationSkipAll bool -} - -type keyfunc struct { - ctx context.Context - storage jwkset.Storage - useWhitelist []jwkset.USE - allowedAlgorithms []string -} - -// New creates a new Keyfunc. -func New(options Options) (Keyfunc, error) { - ctx := options.Ctx - if ctx == nil { - ctx = context.Background() - } - if options.Storage == nil { - return nil, fmt.Errorf("%w: no JWK Set storage given in options", ErrKeyfunc) - } - k := keyfunc{ - ctx: ctx, - storage: options.Storage, - useWhitelist: options.UseWhitelist, - allowedAlgorithms: options.AllowedAlgorithms, - } - return k, nil -} - -// NewDefault creates a new Keyfunc with a default JWK Set storage and options. -// -// This will launch "refresh goroutine" to automatically refresh the remote HTTP resources. -func NewDefault(urls []string) (Keyfunc, error) { - return NewDefaultCtx(context.Background(), urls) -} - -// NewDefaultCtx creates a new Keyfunc with a default JWK Set storage and options. The context is used to end the -// "refresh goroutine". -// -// This will launch "refresh goroutine" to automatically refresh the remote HTTP resources. -func NewDefaultCtx(ctx context.Context, urls []string) (Keyfunc, error) { - client, err := jwkset.NewDefaultHTTPClientCtx(ctx, urls) - if err != nil { - return nil, err - } - options := Options{ - Storage: client, - } - return New(options) -} - -// NewDefaultOverrideCtx creates a new Keyfunc with a default JWK Set storage and options. The context is used to end -// the "refresh goroutine". The override parameter is used to change specific default behaviors. -// -// This will launch "refresh goroutine" to automatically refresh remote HTTP resources. -func NewDefaultOverrideCtx(ctx context.Context, urls []string, override Override) (Keyfunc, error) { - rateLimitWaitMax := time.Minute - if override.RateLimitWaitMax != 0 { - rateLimitWaitMax = override.RateLimitWaitMax - } - refreshErrorHandler := func(u string) func(ctx context.Context, err error) { - return func(ctx context.Context, err error) { - slog.Default().ErrorContext(ctx, "Failed to refresh HTTP JWK Set from remote HTTP resource.", - "error", err, - "url", u, - ) - } - } - if override.RefreshErrorHandlerFunc != nil { - refreshErrorHandler = override.RefreshErrorHandlerFunc - } - refreshInterval := time.Hour - if override.RefreshInterval > 0 { - refreshInterval = override.RefreshInterval - } - refreshUnknownKID := rate.NewLimiter(rate.Every(5*time.Minute), 1) - if override.RefreshUnknownKID != nil { - refreshUnknownKID = override.RefreshUnknownKID - } - - clientOptions := jwkset.HTTPClientOptions{ - HTTPURLs: make(map[string]jwkset.Storage), - RateLimitWaitMax: rateLimitWaitMax, - RefreshUnknownKID: refreshUnknownKID, - } - for _, u := range urls { - errorHandler := refreshErrorHandler(u) - options := jwkset.HTTPClientStorageOptions{ - Ctx: ctx, - NoErrorReturnFirstHTTPReq: true, - RefreshErrorHandler: errorHandler, - RefreshInterval: refreshInterval, - ValidateOptions: jwkset.JWKValidateOptions{ - SkipAll: override.ValidationSkipAll, - }, - } - - if override.HTTPTimeout > 0 { - options.HTTPTimeout = override.HTTPTimeout - } - - c, err := jwkset.NewStorageFromHTTP(u, options) - if err != nil { - return nil, fmt.Errorf("failed to create HTTP client storage for %q: %w", u, errors.Join(err, jwkset.ErrNewClient)) - } - clientOptions.HTTPURLs[u] = c - } - storage, err := jwkset.NewHTTPClient(clientOptions) - if err != nil { - return nil, fmt.Errorf("failed to create HTTP client storage: %w", errors.Join(err, jwkset.ErrNewClient)) - } - options := Options{ - Ctx: ctx, - Storage: storage, - UseWhitelist: nil, - } - return New(options) -} - -// NewJWKJSON creates a new Keyfunc from raw JWK JSON. -func NewJWKJSON(raw json.RawMessage) (Keyfunc, error) { - marshalOptions := jwkset.JWKMarshalOptions{ - Private: true, - } - jwk, err := jwkset.NewJWKFromRawJSON(raw, marshalOptions, jwkset.JWKValidateOptions{}) - if err != nil { - return nil, fmt.Errorf("%w: could not create JWK from raw JSON", errors.Join(err, ErrKeyfunc)) - } - store := jwkset.NewMemoryStorage() - err = store.KeyWrite(context.Background(), jwk) - if err != nil { - return nil, fmt.Errorf("%w: could not write JWK to storage", errors.Join(err, ErrKeyfunc)) - } - options := Options{ - Storage: store, - } - return New(options) -} - -// NewJWKSetJSON creates a new Keyfunc from raw JWK Set JSON. -func NewJWKSetJSON(raw json.RawMessage) (Keyfunc, error) { - var jwks jwkset.JWKSMarshal - err := json.Unmarshal(raw, &jwks) - if err != nil { - return nil, fmt.Errorf("%w: could not unmarshal raw JWK Set JSON", errors.Join(err, ErrKeyfunc)) - } - store, err := jwks.ToStorage() - if err != nil { - return nil, fmt.Errorf("%w: could not create JWK Set storage", errors.Join(err, ErrKeyfunc)) - } - options := Options{ - Storage: store, - } - return New(options) -} - -func (k keyfunc) KeyfuncCtx(ctx context.Context) jwt.Keyfunc { - return func(token *jwt.Token) (any, error) { - kidInter, ok := token.Header[jwkset.HeaderKID] - if !ok { - return k.VerificationKeySet(ctx) - } - kid, ok := kidInter.(string) - if !ok { - return nil, fmt.Errorf("%w: could not convert kid in JWT header to string", ErrKeyfunc) - } - algInter, ok := token.Header["alg"] - if !ok { - return nil, fmt.Errorf("%w: could not find alg in JWT header", ErrKeyfunc) - } - alg, ok := algInter.(string) - if !ok { - // For test coverage purposes, this should be impossible to reach because the JWT package rejects a token - // without an alg parameter in the header before calling jwt.Keyfunc. - return nil, fmt.Errorf(`%w: the JWT header did not contain the "alg" parameter, which is required by RFC 7515 section 4.1.1`, ErrKeyfunc) - } - - // When an algorithm is actually provided in the jwks the current keyfunc will validate the - // jwks algorithm with it. But when no algorithm is provided (alg: none or missing alg) - // the default keyfunc will not validate the algorithm as it has nothing to cross check. - if len(k.allowedAlgorithms) > 0 { - // This is a custom validation different from the original keyfunc.Keyfunc - if !slices.Contains(k.allowedAlgorithms, alg) { - return nil, fmt.Errorf("%w: could not find alg %s in allow list", ErrKeyfunc, alg) - } - } - - jwk, err := k.storage.KeyRead(ctx, kid) - if err != nil { - return nil, fmt.Errorf("%w: could not read JWK from storage", errors.Join(err, ErrKeyfunc)) - } - - if a := jwk.Marshal().ALG.String(); a != "" && a != alg { - return nil, fmt.Errorf(`%w: JWK "alg" parameter value %q does not match token "alg" parameter value %q`, ErrKeyfunc, a, alg) - } - if len(k.useWhitelist) > 0 { - found := false - for _, u := range k.useWhitelist { - if jwk.Marshal().USE == u { - found = true - break - } - } - if !found { - return nil, fmt.Errorf(`%w: JWK "use" parameter value %q is not in whitelist`, ErrKeyfunc, jwk.Marshal().USE) - } - } - - key := jwk.Key() - pk, ok := key.(publicKeyer) - if ok { - key = pk.Public() - } - - return key, nil - } -} -func (k keyfunc) Keyfunc(token *jwt.Token) (any, error) { - keyF := k.KeyfuncCtx(k.ctx) - return keyF(token) -} -func (k keyfunc) Storage() jwkset.Storage { - return k.storage -} -func (k keyfunc) VerificationKeySet(ctx context.Context) (jwt.VerificationKeySet, error) { - jwk, err := k.storage.KeyReadAll(ctx) - if err != nil { - return jwt.VerificationKeySet{}, fmt.Errorf("failed to read all JWK from storage: %w", errors.Join(err, ErrKeyfunc)) - } - var allKeys jwt.VerificationKeySet - for _, j := range jwk { - key := j.Key() - pk, ok := key.(publicKeyer) - if ok { - key = pk.Public() - } - allKeys.Keys = append(allKeys.Keys, key) - } - return allKeys, nil -} - -type publicKeyer interface { - Public() crypto.PublicKey -} diff --git a/router/pkg/authentication/keyfunc/keyfunc_test.go b/router/pkg/authentication/keyfunc/keyfunc_test.go deleted file mode 100644 index 1b2b28843f..0000000000 --- a/router/pkg/authentication/keyfunc/keyfunc_test.go +++ /dev/null @@ -1,302 +0,0 @@ -// This is forked from https://github.com/MicahParks/keyfunc/blob/main/keyfunc.go -// Copyrights go to the original author. -package keyfunc - -import ( - "context" - "crypto/ed25519" - "crypto/rand" - "encoding/json" - "errors" - "net/http" - "net/http/httptest" - "reflect" - "testing" - - "github.com/MicahParks/jwkset" - "github.com/golang-jwt/jwt/v5" -) - -const ( - keyID = "my-key-id" -) - -func TestNew(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - _, priv, err := ed25519.GenerateKey(rand.Reader) - if err != nil { - t.Fatalf("Failed to generate ED25519 key pair. Error: %s", err) - } - jwk, err := jwkset.NewJWKFromKey(priv, jwkset.JWKOptions{}) - if err != nil { - t.Fatalf("Failed to create JWK from ED25519 private key. Error: %s", err) - } - - serverStore := jwkset.NewMemoryStorage() - err = serverStore.KeyWrite(ctx, jwk) - if err != nil { - t.Fatalf("Failed to write ED25519 public key to server store. Error: %s", err) - } - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - rawJWKS, err := serverStore.JSONPrivate(ctx) - if err != nil { - t.Fatalf("Failed to get JWK Set JSON from server store. Error: %s", err) - } - _, _ = w.Write(rawJWKS) - })) - defer server.Close() - - token := jwt.New(jwt.SigningMethodEdDSA) - token.Header[jwkset.HeaderKID] = keyID - signed, err := token.SignedString(priv) - if err != nil { - t.Fatalf("Failed to sign JWT. Error: %s", err) - } - - clientStore, err := jwkset.NewDefaultHTTPClient([]string{server.URL}) - if err != nil { - t.Fatalf("Failed to create client store. Error: %s", err) - } - options := Options{ - Ctx: ctx, - Storage: clientStore, - UseWhitelist: []jwkset.USE{jwkset.UseSig}, - } - k, err := New(options) - if err != nil { - t.Fatalf("Failed to create keyfunc. Error: %s", err) - } - - _, err = jwt.Parse(signed, k.Keyfunc) - if !errors.Is(err, ErrKeyfunc) { - t.Fatalf("Expected ErrKeyfunc for missing Key ID in header, but got %s.", err) - } - - metadata := jwkset.JWKMetadataOptions{ - KID: keyID, - USE: jwkset.UseSig, - } - jwkOptions := jwkset.JWKOptions{ - Metadata: metadata, - } - jwk, err = jwkset.NewJWKFromKey(priv, jwkOptions) - if err != nil { - t.Fatalf("Failed to create JWK from ED25519 private key. Error: %s", err) - } - err = serverStore.KeyWrite(ctx, jwk) - if err != nil { - t.Fatalf("Failed to write ED25519 public key to server store. Error: %s", err) - } - - clientStore, err = jwkset.NewDefaultHTTPClient([]string{server.URL}) - if err != nil { - t.Fatalf("Failed to create client store. Error: %s", err) - } - options.Storage = clientStore - k, err = New(options) - if err != nil { - t.Fatalf("Failed to create keyfunc. Error: %s", err) - } - - _, err = jwt.Parse(signed, k.Keyfunc) - if err != nil { - t.Fatalf("Failed to parse JWT. Error: %s", err) - } - - if !reflect.DeepEqual(k.Storage(), clientStore) { - t.Fatalf("Expected client store, but got something else.") - } - - _, err = NewDefault([]string{server.URL}) - if err != nil { - t.Fatalf("Failed to create keyfunc. Error: %s", err) - } - - _, err = NewDefaultOverrideCtx(ctx, []string{server.URL}, Override{}) - if err != nil { - t.Fatalf("Failed to create keyfunc with overrides. Error: %s", err) - } -} - -func TestNewErr(t *testing.T) { - _, err := New(Options{}) - if !errors.Is(err, ErrKeyfunc) { - t.Error("Expected ErrKeyfunc, but got nil.") - } -} - -func TestNewJWKJSON(t *testing.T) { - // Get the JWK as JSON. - jwksJSON := json.RawMessage(`{"kty": "RSA","e": "AQAB","kid": "ee8d626d","n": "gRda5b0pkgTytDuLrRnNSYhvfMIyM0ASq2ZggY4dVe12JV8N7lyXilyqLKleD-2lziivvzE8O8CdIC2vUf0tBD7VuMyldnZruSEZWCuKJPdgKgy9yPpShmD2NyhbwQIAbievGMJIp_JMwz8MkdY5pzhPECGNgCEtUAmsrrctP5V8HuxaxGt9bb-DdPXkYWXW3MPMSlVpGZ5GiIeTABxqYNG2MSoYeQ9x8O3y488jbassTqxExI_4w9MBQBJR9HIXjWrrrenCcDlMY71rzkbdj3mmcn9xMq2vB5OhfHyHTihbUPLSm83aFWSuW9lE7ogMc93XnrB8evIAk6VfsYlS9Q"}`) - - // Create the keyfunc.Keyfunc. - k, err := NewJWKJSON(jwksJSON) - if err != nil { - t.Fatalf("Failed to create a keyfunc.Keyfunc.\nError: %s", err) - } - - // Get a JWT to parse. - jwtB64 := "eyJraWQiOiJlZThkNjI2ZCIsInR5cCI6IkpXVCIsImFsZyI6IlJTMjU2In0.eyJzdWIiOiJXZWlkb25nIiwiYXVkIjoiVGFzaHVhbiIsImlzcyI6Imp3a3Mtc2VydmljZS5hcHBzcG90LmNvbSIsImlhdCI6MTYzMTM2OTk1NSwianRpIjoiNDY2M2E5MTAtZWU2MC00NzcwLTgxNjktY2I3NDdiMDljZjU0In0.LwD65d5h6U_2Xco81EClMa_1WIW4xXZl8o4b7WzY_7OgPD2tNlByxvGDzP7bKYA9Gj--1mi4Q4li4CAnKJkaHRYB17baC0H5P9lKMPuA6AnChTzLafY6yf-YadA7DmakCtIl7FNcFQQL2DXmh6gS9J6TluFoCIXj83MqETbDWpL28o3XAD_05UP8VLQzH2XzyqWKi97mOuvz-GsDp9mhBYQUgN3csNXt2v2l-bUPWe19SftNej0cxddyGu06tXUtaS6K0oe0TTbaqc3hmfEiu5G0J8U6ztTUMwXkBvaknE640NPgMQJqBaey0E4u0txYgyvMvvxfwtcOrDRYqYPBnA" - - // Parse the JWT. - token, err := jwt.Parse(jwtB64, k.Keyfunc) - if err != nil { - t.Fatalf("Failed to parse the JWT.\nError: %s", err) - } - - // Check if the token is valid. - if !token.Valid { - t.Fatalf("The token is not valid.") - } -} - -func TestNewJWKSetJSON(t *testing.T) { - // Get the JWK Set as JSON. - jwksJSON := json.RawMessage(`{"keys":[{"kty":"RSA","e":"AQAB","kid":"ee8d626d","n":"gRda5b0pkgTytDuLrRnNSYhvfMIyM0ASq2ZggY4dVe12JV8N7lyXilyqLKleD-2lziivvzE8O8CdIC2vUf0tBD7VuMyldnZruSEZWCuKJPdgKgy9yPpShmD2NyhbwQIAbievGMJIp_JMwz8MkdY5pzhPECGNgCEtUAmsrrctP5V8HuxaxGt9bb-DdPXkYWXW3MPMSlVpGZ5GiIeTABxqYNG2MSoYeQ9x8O3y488jbassTqxExI_4w9MBQBJR9HIXjWrrrenCcDlMY71rzkbdj3mmcn9xMq2vB5OhfHyHTihbUPLSm83aFWSuW9lE7ogMc93XnrB8evIAk6VfsYlS9Q"},{"kty":"EC","crv":"P-256","kid":"711d48d1","x":"tfXCoBU-wXemeQCkME1gMZWK0-UECCHIkedASZR0t-Q","y":"9xzYtnKQdiQJHCtGwpZWF21eP1fy5x4wC822rCilmBw"},{"kty":"EC","crv":"P-384","kid":"d52c9829","x":"tFx6ev6eLs9sNfdyndn4OgbhV6gPFVn7Ul0VD5vwuplJLbIYeFLI6T42tTaE5_Q4","y":"A0gzB8TqxPX7xMzyHH_FXkYG2iROANH_kQxBovSeus6l_QSyqYlipWpBy9BhY9dz"},{"kty":"RSA","e":"AQAB","kid":"ecac72e5","n":"nLbnTvZAUxdmuAbDDUNAfha6mw0fri3UpV2w1PxilflBuSnXJhzo532-YQITogoanMjy_sQ8kHUhZYHVRR6vLZRBBbl-hP8XWiCe4wwioy7Ey3TiIUYfW-SD6I42XbLt5o-47IR0j5YDXxnX2UU7-UgR_kITBeLDfk0rSp4B0GUhPbP5IDItS0MHHDDS3lhvJomxgEfoNrp0K0Fz_s0K33hfOqc2hD1tSkX-3oDTQVRMF4Nxax3NNw8-ahw6HNMlXlwWfXodgRMvj9pcz8xUYa3C5IlPlZkMumeNCFx1qds6K_eYcU0ss91DdbhhE8amRX1FsnBJNMRUkA5i45xkOIx15rQN230zzh0p71jvtx7wYRr5pdMlwxV0T9Ck5PCmx-GzFazA2X6DJ0Xnn1-cXkRoZHFj_8Mba1dUrNz-NWEk83uW5KT-ZEbX7nzGXtayKWmGb873a8aYPqIsp6bQ_-eRBd8TDT2g9HuPyPr5VKa1p33xKaohz4DGy3t1Qpy3UWnbPXUlh5dLWPKz-TcS9FP5gFhWVo-ZhU03Pn6P34OxHmXGWyQao18dQGqzgD4e9vY3rLhfcjVZJYNlWY2InsNwbYS-DnienPf1ws-miLeXxNKG3tFydoQzHwyOxG6Wc-HBfzL_hOvxINKQamvPasaYWl1LWznMps6elKCgKDc"},{"kty":"EC","crv":"P-521","kid":"c570888f","x":"AHNpXq0J7rikNRlwhaMYDD8LGVAVJzNJ-jEPksUIn2LB2LCdNRzfAhgbxdQcWT9ktlc9M1EhmTLccEqfnWdGL9G1","y":"AfHPUW3GYzzqbTczcYR0nYMVMFVrYsUxv4uiuSNV_XRN3Jf8zeYbbOLJv4S3bUytO7qHY8bfZxPxR9nn3BBTf5ol"}]}`) - - // Create the keyfunc.Keyfunc. - k, err := NewJWKSetJSON(jwksJSON) - if err != nil { - t.Fatalf("Failed to create a keyfunc.Keyfunc.\nError: %s", err) - } - - // Get a JWT to parse. - jwtB64 := "eyJraWQiOiJlZThkNjI2ZCIsInR5cCI6IkpXVCIsImFsZyI6IlJTMjU2In0.eyJzdWIiOiJXZWlkb25nIiwiYXVkIjoiVGFzaHVhbiIsImlzcyI6Imp3a3Mtc2VydmljZS5hcHBzcG90LmNvbSIsImlhdCI6MTYzMTM2OTk1NSwianRpIjoiNDY2M2E5MTAtZWU2MC00NzcwLTgxNjktY2I3NDdiMDljZjU0In0.LwD65d5h6U_2Xco81EClMa_1WIW4xXZl8o4b7WzY_7OgPD2tNlByxvGDzP7bKYA9Gj--1mi4Q4li4CAnKJkaHRYB17baC0H5P9lKMPuA6AnChTzLafY6yf-YadA7DmakCtIl7FNcFQQL2DXmh6gS9J6TluFoCIXj83MqETbDWpL28o3XAD_05UP8VLQzH2XzyqWKi97mOuvz-GsDp9mhBYQUgN3csNXt2v2l-bUPWe19SftNej0cxddyGu06tXUtaS6K0oe0TTbaqc3hmfEiu5G0J8U6ztTUMwXkBvaknE640NPgMQJqBaey0E4u0txYgyvMvvxfwtcOrDRYqYPBnA" - - // Parse the JWT. - token, err := jwt.Parse(jwtB64, k.Keyfunc) - if err != nil { - t.Fatalf("Failed to parse the JWT.\nError: %s", err) - } - - // Check if the token is valid. - if !token.Valid { - t.Fatalf("The token is not valid.") - } -} - -func TestVerificationKeySet(t *testing.T) { - ctx := context.Background() - _, priv, err := ed25519.GenerateKey(rand.Reader) - if err != nil { - t.Fatalf("Failed to generate ED25519 key pair: %v", err) - } - jwk, err := jwkset.NewJWKFromKey(priv, jwkset.JWKOptions{}) - if err != nil { - t.Fatalf("Failed to create JWK: %v", err) - } - store := jwkset.NewMemoryStorage() - err = store.KeyWrite(ctx, jwk) - if err != nil { - t.Fatalf("Failed to write JWK: %v", err) - } - k, err := New(Options{Ctx: ctx, Storage: store}) - if err != nil { - t.Fatalf("Failed to create Keyfunc: %v", err) - } - vks, err := k.VerificationKeySet(ctx) - if err != nil { - t.Fatalf("VerificationKeySet failed: %v", err) - } - if len(vks.Keys) != 1 { - t.Fatalf("Expected 1 key, got %d", len(vks.Keys)) - } -} - -func TestNoKIDHeaderCallsVerificationKeySet(t *testing.T) { - ctx := context.Background() - - // Generate two key pairs. - _, priv1, err := ed25519.GenerateKey(rand.Reader) - if err != nil { - t.Fatalf("Failed to generate ED25519 key pair 1: %v", err) - } - _, priv2, err := ed25519.GenerateKey(rand.Reader) - if err != nil { - t.Fatalf("Failed to generate ED25519 key pair 2: %v", err) - } - - jwk1, err := jwkset.NewJWKFromKey(priv1, jwkset.JWKOptions{}) - if err != nil { - t.Fatalf("Failed to create JWK 1: %v", err) - } - jwk2, err := jwkset.NewJWKFromKey(priv2, jwkset.JWKOptions{}) - if err != nil { - t.Fatalf("Failed to create JWK 2: %v", err) - } - - orders := [][]jwkset.JWK{ - {jwk1, jwk2}, - {jwk2, jwk1}, - } - privs := []ed25519.PrivateKey{priv1, priv2} - - for i, order := range orders { - store := jwkset.NewMemoryStorage() - for _, jwk := range order { - err = store.KeyWrite(ctx, jwk) - if err != nil { - t.Fatalf("Failed to write JWK: %v", err) - } - } - k, err := New(Options{Ctx: ctx, Storage: store}) - if err != nil { - t.Fatalf("Failed to create Keyfunc: %v", err) - } - // Sign a token with the corresponding private key (no KID header) - token := jwt.New(jwt.SigningMethodEdDSA) - tokenString, err := token.SignedString(privs[i]) - if err != nil { - t.Fatalf("Failed to sign token: %v", err) - } - parsedToken, err := jwt.Parse(tokenString, k.KeyfuncCtx(ctx)) - if err != nil { - t.Fatalf("Parse failed (order %d): %v", i+1, err) - } - if !parsedToken.Valid { - t.Fatalf("Expected token to be valid (order %d)", i+1) - } - } -} - -func TestNoKIDHeaderNoMatchingJWK(t *testing.T) { - ctx := context.Background() - - _, missingFromSet, err := ed25519.GenerateKey(rand.Reader) - if err != nil { - t.Fatalf("Failed to generate ED25519 key pair: %v", err) - } - - _, presentInSet, err := ed25519.GenerateKey(rand.Reader) - if err != nil { - t.Fatalf("Failed to generate other ED25519 key pair: %v", err) - } - jwk, err := jwkset.NewJWKFromKey(presentInSet, jwkset.JWKOptions{}) - if err != nil { - t.Fatalf("Failed to create JWK: %v", err) - } - store := jwkset.NewMemoryStorage() - err = store.KeyWrite(ctx, jwk) - if err != nil { - t.Fatalf("Failed to write JWK: %v", err) - } - - k, err := New(Options{Ctx: ctx, Storage: store}) - if err != nil { - t.Fatalf("Failed to create Keyfunc: %v", err) - } - - token := jwt.New(jwt.SigningMethodEdDSA) - tokenString, err := token.SignedString(missingFromSet) - if err != nil { - t.Fatalf("Failed to sign token: %v", err) - } - - _, err = jwt.Parse(tokenString, k.KeyfuncCtx(ctx)) - if err == nil { - t.Fatalf("Expected error due to no matching JWK, but got none") - } -} From b3a4a4a407641354c38dc95bed2b154c6dda3df8 Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Mon, 22 Sep 2025 19:36:58 +0530 Subject: [PATCH 32/45] fix: go mod tidy --- router-tests/go.mod | 2 +- router-tests/go.sum | 4 ++-- router/go.mod | 3 +-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/router-tests/go.mod b/router-tests/go.mod index 7b3061f68c..9fceb9f617 100644 --- a/router-tests/go.mod +++ b/router-tests/go.mod @@ -45,7 +45,6 @@ require ( connectrpc.com/connect v1.16.2 // indirect github.com/99designs/gqlgen v0.17.76 // indirect github.com/KimMachineGun/automemlimit v0.6.1 // indirect - github.com/MicahParks/keyfunc/v3 v3.3.5 // indirect github.com/agnivade/levenshtein v1.2.1 // indirect github.com/andybalholm/brotli v1.1.0 // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect @@ -151,6 +150,7 @@ require ( github.com/vbatts/tar-split v0.12.1 // indirect github.com/vektah/gqlparser/v2 v2.5.30 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + github.com/wundergraph/keyfunc/v3 v3.0.0-20250922133930-92f21becf3d9 // indirect github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect diff --git a/router-tests/go.sum b/router-tests/go.sum index 8ba0cbab58..0b8d19ce7f 100644 --- a/router-tests/go.sum +++ b/router-tests/go.sum @@ -7,8 +7,6 @@ github.com/KimMachineGun/automemlimit v0.6.1 h1:ILa9j1onAAMadBsyyUJv5cack8Y1WT26 github.com/KimMachineGun/automemlimit v0.6.1/go.mod h1:T7xYht7B8r6AG/AqFcUdc7fzd2bIdBKmepfP2S1svPY= github.com/MicahParks/jwkset v0.11.0 h1:yc0zG+jCvZpWgFDFmvs8/8jqqVBG9oyIbmBtmjOhoyQ= github.com/MicahParks/jwkset v0.11.0/go.mod h1:U2oRhRaLgDCLjtpGL2GseNKGmZtLs/3O7p+OZaL5vo0= -github.com/MicahParks/keyfunc/v3 v3.3.5 h1:7ceAJLUAldnoueHDNzF8Bx06oVcQ5CfJnYwNt1U3YYo= -github.com/MicahParks/keyfunc/v3 v3.3.5/go.mod h1:SdCCyMJn/bYqWDvARspC6nCT8Sk74MjuAY22C7dCST8= github.com/PuerkitoBio/goquery v1.10.3 h1:pFYcNSqHxBD06Fpj/KsbStFRsgRATgnf3LeXiUkhzPo= github.com/PuerkitoBio/goquery v1.10.3/go.mod h1:tMUX0zDMHXYlAQk6p35XxQMqMweEKB7iK7iLNd4RH4Y= github.com/agnivade/levenshtein v1.2.1 h1:EHBY3UOn1gwdy/VbFwgo4cxecRznFk7fKWN1KOX7eoM= @@ -354,6 +352,8 @@ github.com/wundergraph/consul/sdk v0.0.0-20250204115147-ed842a8fd301 h1:EzfKHQoT github.com/wundergraph/consul/sdk v0.0.0-20250204115147-ed842a8fd301/go.mod h1:wxI0Nak5dI5RvJuzGyiEK4nZj0O9X+Aw6U0tC1wPKq0= github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.226 h1:3g6KNCG4ydgnpZnIlCK7pmtv0FSge6ILUS5LjrNZNiI= github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.226/go.mod h1:g1IFIylu5Fd9pKjzq0mDvpaKhEB/vkwLAIbGdX2djXU= +github.com/wundergraph/keyfunc/v3 v3.0.0-20250922133930-92f21becf3d9 h1:7bPpsPUUxy5dEnuDSy2q3PAmflxqKx9vnyaTj3TSMBo= +github.com/wundergraph/keyfunc/v3 v3.0.0-20250922133930-92f21becf3d9/go.mod h1:el0U1ewqJ/T/Urlt3wImfmuBmoQdjL5yoNQ5e/+O98M= github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 h1:gEOO8jv9F4OT7lGCjxCBTO/36wtF6j2nSip77qHd4x4= github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= diff --git a/router/go.mod b/router/go.mod index a199025799..2fd3bb1abf 100644 --- a/router/go.mod +++ b/router/go.mod @@ -4,7 +4,6 @@ go 1.25 require ( connectrpc.com/connect v1.16.2 - github.com/MicahParks/keyfunc/v3 v3.3.5 github.com/andybalholm/brotli v1.1.0 // indirect github.com/buger/jsonparser v1.1.1 github.com/cespare/xxhash/v2 v2.3.0 @@ -79,6 +78,7 @@ require ( github.com/santhosh-tekuri/jsonschema/v6 v6.0.1 github.com/tonglil/opentelemetry-go-datadog-propagator v0.1.3 github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 + github.com/wundergraph/keyfunc/v3 v3.0.0-20250922133930-92f21becf3d9 go.uber.org/goleak v1.3.0 go.uber.org/ratelimit v0.3.1 golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 @@ -174,7 +174,6 @@ require ( // Do not upgrade, it renames attributes we rely on replace ( - github.com/MicahParks/keyfunc/v3 => github.com/wundergraph/keyfunc/v3 v3.0.0-20250922133930-92f21becf3d9 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp => go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.46.1 go.opentelemetry.io/contrib/propagators/b3 => go.opentelemetry.io/contrib/propagators/b3 v1.23.0 go.opentelemetry.io/contrib/propagators/jaeger => go.opentelemetry.io/contrib/propagators/jaeger v1.23.0 From 8b4f77c6cbdc6301a8dd78ec891b1f69c47d8719 Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Mon, 22 Sep 2025 19:37:06 +0530 Subject: [PATCH 33/45] fix: update dependency --- router/pkg/authentication/jwks_token_decoder.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/router/pkg/authentication/jwks_token_decoder.go b/router/pkg/authentication/jwks_token_decoder.go index ce19603bc9..f472577ea1 100644 --- a/router/pkg/authentication/jwks_token_decoder.go +++ b/router/pkg/authentication/jwks_token_decoder.go @@ -4,7 +4,7 @@ import ( "context" "errors" "fmt" - "github.com/MicahParks/keyfunc/v3" + "github.com/wundergraph/keyfunc/v3" "net/http" "time" From 4ec706df60f5d0de4111d79f86cb2c2d35df5242 Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Mon, 22 Sep 2025 20:50:52 +0530 Subject: [PATCH 34/45] fix: tests --- router-tests/authentication_test.go | 154 ++++++++++++++++++++++++++++ 1 file changed, 154 insertions(+) diff --git a/router-tests/authentication_test.go b/router-tests/authentication_test.go index 280130c87b..db7e01237e 100644 --- a/router-tests/authentication_test.go +++ b/router-tests/authentication_test.go @@ -2,11 +2,15 @@ package integration import ( "bytes" + "crypto/ed25519" "crypto/rsa" "crypto/x509" + "encoding/base64" + "encoding/json" "encoding/pem" "io" "net/http" + "net/http/httptest" "strings" "sync" "sync/atomic" @@ -2358,6 +2362,117 @@ func TestSupportedAlgorithms(t *testing.T) { }) } +func TestJWKSIgnoreUnsupportedKeys(t *testing.T) { + t.Parallel() + + // Create one supported RSA key and one unsupported OKP(Ed448) entry in the same JWKS. + rsaProvider, err := jwks.NewRSACrypto("", jwkset.AlgRS256, 2048) + require.NoError(t, err) + + edProvider, err := jwks.NewED25519Crypto("") + require.NoError(t, err) + + rsaKID := rsaProvider.KID() + unsupportedKID := edProvider.KID() + "-unsupported" + + // Build RSA public JWK + rsaPriv := rsaProvider.PrivateKey().(*rsa.PrivateKey) + rsaPub := rsaPriv.PublicKey + nB64 := base64.RawURLEncoding.EncodeToString(rsaPub.N.Bytes()) + eB64 := base64.RawURLEncoding.EncodeToString(bigIntBytes(int64(rsaPub.E))) + + rsaJWK := map[string]any{ + "kty": "RSA", + "n": nB64, + "e": eB64, + "use": "sig", + "alg": jwkset.AlgRS256.String(), + "kid": rsaKID, + } + + // Build unsupported OKP Ed448 JWK (intentionally unsupported curve) + edPriv := edProvider.PrivateKey().(ed25519.PrivateKey) + edPub := edPriv.Public().(ed25519.PublicKey) + xB64 := base64.RawURLEncoding.EncodeToString(edPub) + unsupportedJWK := map[string]any{ + "kty": "OKP", + "crv": "Ed448", + "x": xB64, + "use": "sig", + "alg": jwkset.AlgEdDSA.String(), + "kid": unsupportedKID, + } + + jwksJSON := map[string]any{ + "keys": []any{rsaJWK, unsupportedJWK}, + } + + mux := http.NewServeMux() + mux.HandleFunc("/.well-known/jwks.json", func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(jwksJSON) + }) + ts := httptest.NewServer(mux) + t.Cleanup(ts.Close) + + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + URL: ts.URL + "/.well-known/jwks.json", + RefreshInterval: time.Second * 5, + }, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // 1) Supported RSA token should succeed + rsaToken := jwt.New(jwt.SigningMethodRS256) + rsaToken.Header[jwkset.HeaderKID] = rsaKID + signedRSA, err := rsaToken.SignedString(rsaPriv) + require.NoError(t, err) + + header := http.Header{"Authorization": []string{"Bearer " + signedRSA}} + res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res.Body.Close() }() + require.Equal(t, http.StatusOK, res.StatusCode) + require.Equal(t, JwksName, res.Header.Get(xAuthenticatedByHeader)) + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, employeesExpectedData, string(body)) + + // 2) Token with KID pointing to an unsupported Ed448 key in JWKS should fail + edToken := jwt.New(jwt.SigningMethodEdDSA) + edToken.Header[jwkset.HeaderKID] = unsupportedKID + signedEd, err := edToken.SignedString(edPriv) + require.NoError(t, err) + + header2 := http.Header{"Authorization": []string{"Bearer " + signedEd}} + res2, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header2, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res2.Body.Close() }() + require.Equal(t, http.StatusUnauthorized, res2.StatusCode) + body2, err := io.ReadAll(res2.Body) + require.NoError(t, err) + require.JSONEq(t, unauthorizedExpectedData, string(body2)) + }) +} + +// bigIntBytes converts an int64 to big-endian bytes without leading zeros trimmed by base64 encoder behavior. +func bigIntBytes(v int64) []byte { + // Minimal big-endian bytes for the exponent + b := make([]byte, 0, 8) + for v > 0 { + b = append([]byte{byte(v & 0xff)}, b...) + v >>= 8 + } + if len(b) == 0 { + return []byte{0} + } + return b +} + func TestAuthenticationOverWebsocket(t *testing.T) { t.Parallel() @@ -2912,6 +3027,45 @@ func TestAudienceValidation(t *testing.T) { require.Equal(t, http.StatusUnauthorized, res2.StatusCode) }) }) + + t.Run("verify blocking invalid algorithm", func(t *testing.T) { + t.Parallel() + + rsaCrypto, err := jwks.NewRSACrypto("", "R4ND0M", 2048) + require.NoError(t, err) + + authServer, err := jwks.NewServerWithCrypto(t, rsaCrypto) + require.NoError(t, err) + t.Cleanup(authServer.Close) + + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + toJWKSConfig(authServer.JWKSURL(), time.Second*5), + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Manually craft a JWT with an unregistered/unknown alg value + hdr := map[string]any{"alg": "R4ND0M", "typ": "JWT", jwkset.HeaderKID: rsaCrypto.KID()} + pl := map[string]any{} + hBytes, err := json.Marshal(hdr) + require.NoError(t, err) + pBytes, err := json.Marshal(pl) + require.NoError(t, err) + signed := base64.RawURLEncoding.EncodeToString(hBytes) + "." + base64.RawURLEncoding.EncodeToString(pBytes) + ".bogus" + + header := http.Header{"Authorization": []string{"Bearer " + signed}} + res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res.Body.Close() }() + require.Equal(t, http.StatusUnauthorized, res.StatusCode) + data, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.JSONEq(t, unauthorizedExpectedData, string(data)) + }) + }) } func toJWKSConfig(url string, refresh time.Duration, allowedAlgorithms ...string) authentication.JWKSConfig { From 6c6b1308b4dcaa3428e2acf2fb3d25822dc4ac3f Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Mon, 22 Sep 2025 20:52:27 +0530 Subject: [PATCH 35/45] fix: tests --- router-tests/authentication_test.go | 111 ---------------------------- 1 file changed, 111 deletions(-) diff --git a/router-tests/authentication_test.go b/router-tests/authentication_test.go index db7e01237e..61d3286b22 100644 --- a/router-tests/authentication_test.go +++ b/router-tests/authentication_test.go @@ -2362,117 +2362,6 @@ func TestSupportedAlgorithms(t *testing.T) { }) } -func TestJWKSIgnoreUnsupportedKeys(t *testing.T) { - t.Parallel() - - // Create one supported RSA key and one unsupported OKP(Ed448) entry in the same JWKS. - rsaProvider, err := jwks.NewRSACrypto("", jwkset.AlgRS256, 2048) - require.NoError(t, err) - - edProvider, err := jwks.NewED25519Crypto("") - require.NoError(t, err) - - rsaKID := rsaProvider.KID() - unsupportedKID := edProvider.KID() + "-unsupported" - - // Build RSA public JWK - rsaPriv := rsaProvider.PrivateKey().(*rsa.PrivateKey) - rsaPub := rsaPriv.PublicKey - nB64 := base64.RawURLEncoding.EncodeToString(rsaPub.N.Bytes()) - eB64 := base64.RawURLEncoding.EncodeToString(bigIntBytes(int64(rsaPub.E))) - - rsaJWK := map[string]any{ - "kty": "RSA", - "n": nB64, - "e": eB64, - "use": "sig", - "alg": jwkset.AlgRS256.String(), - "kid": rsaKID, - } - - // Build unsupported OKP Ed448 JWK (intentionally unsupported curve) - edPriv := edProvider.PrivateKey().(ed25519.PrivateKey) - edPub := edPriv.Public().(ed25519.PublicKey) - xB64 := base64.RawURLEncoding.EncodeToString(edPub) - unsupportedJWK := map[string]any{ - "kty": "OKP", - "crv": "Ed448", - "x": xB64, - "use": "sig", - "alg": jwkset.AlgEdDSA.String(), - "kid": unsupportedKID, - } - - jwksJSON := map[string]any{ - "keys": []any{rsaJWK, unsupportedJWK}, - } - - mux := http.NewServeMux() - mux.HandleFunc("/.well-known/jwks.json", func(w http.ResponseWriter, r *http.Request) { - _ = json.NewEncoder(w).Encode(jwksJSON) - }) - ts := httptest.NewServer(mux) - t.Cleanup(ts.Close) - - authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ - { - URL: ts.URL + "/.well-known/jwks.json", - RefreshInterval: time.Second * 5, - }, - }) - - testenv.Run(t, &testenv.Config{ - RouterOptions: []core.Option{ - core.WithAccessController(core.NewAccessController(authenticators, true)), - }, - }, func(t *testing.T, xEnv *testenv.Environment) { - // 1) Supported RSA token should succeed - rsaToken := jwt.New(jwt.SigningMethodRS256) - rsaToken.Header[jwkset.HeaderKID] = rsaKID - signedRSA, err := rsaToken.SignedString(rsaPriv) - require.NoError(t, err) - - header := http.Header{"Authorization": []string{"Bearer " + signedRSA}} - res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) - require.NoError(t, err) - defer func() { _ = res.Body.Close() }() - require.Equal(t, http.StatusOK, res.StatusCode) - require.Equal(t, JwksName, res.Header.Get(xAuthenticatedByHeader)) - body, err := io.ReadAll(res.Body) - require.NoError(t, err) - require.Equal(t, employeesExpectedData, string(body)) - - // 2) Token with KID pointing to an unsupported Ed448 key in JWKS should fail - edToken := jwt.New(jwt.SigningMethodEdDSA) - edToken.Header[jwkset.HeaderKID] = unsupportedKID - signedEd, err := edToken.SignedString(edPriv) - require.NoError(t, err) - - header2 := http.Header{"Authorization": []string{"Bearer " + signedEd}} - res2, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header2, strings.NewReader(employeesQuery)) - require.NoError(t, err) - defer func() { _ = res2.Body.Close() }() - require.Equal(t, http.StatusUnauthorized, res2.StatusCode) - body2, err := io.ReadAll(res2.Body) - require.NoError(t, err) - require.JSONEq(t, unauthorizedExpectedData, string(body2)) - }) -} - -// bigIntBytes converts an int64 to big-endian bytes without leading zeros trimmed by base64 encoder behavior. -func bigIntBytes(v int64) []byte { - // Minimal big-endian bytes for the exponent - b := make([]byte, 0, 8) - for v > 0 { - b = append([]byte{byte(v & 0xff)}, b...) - v >>= 8 - } - if len(b) == 0 { - return []byte{0} - } - return b -} - func TestAuthenticationOverWebsocket(t *testing.T) { t.Parallel() From 04e1449ea02203e27d9d876b1b3e75d4b7accd71 Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Mon, 22 Sep 2025 21:06:55 +0530 Subject: [PATCH 36/45] fix: schema --- router/pkg/config/config.schema.json | 3 +++ 1 file changed, 3 insertions(+) diff --git a/router/pkg/config/config.schema.json b/router/pkg/config/config.schema.json index 570d5f8086..9192bb92d9 100644 --- a/router/pkg/config/config.schema.json +++ b/router/pkg/config/config.schema.json @@ -1781,6 +1781,9 @@ }, { "required": ["refresh_interval"] + }, + { + "required": ["refresh_unknown_kid"] } ] } From 2ff8405a21497e6acc14dde36bc7da35cb3ce537 Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Tue, 23 Sep 2025 13:50:41 +0530 Subject: [PATCH 37/45] fix: add tests --- router-tests/authentication_test.go | 56 ++++++++++++++++++- .../pkg/authentication/jwks_token_decoder.go | 12 ++-- 2 files changed, 60 insertions(+), 8 deletions(-) diff --git a/router-tests/authentication_test.go b/router-tests/authentication_test.go index 61d3286b22..db078a6f94 100644 --- a/router-tests/authentication_test.go +++ b/router-tests/authentication_test.go @@ -2,7 +2,6 @@ package integration import ( "bytes" - "crypto/ed25519" "crypto/rsa" "crypto/x509" "encoding/base64" @@ -10,7 +9,6 @@ import ( "encoding/pem" "io" "net/http" - "net/http/httptest" "strings" "sync" "sync/atomic" @@ -2793,7 +2791,61 @@ func TestAudienceValidation(t *testing.T) { require.NoError(t, err) require.JSONEq(t, unauthorizedExpectedData, string(data)) }) + }) + + t.Run("audience validation succeeds even when one audience match fails", func(t *testing.T) { + t.Parallel() + + tokenAudiences := []string{"aud1"} + + authServer1, err := jwks.NewServer(t) + require.NoError(t, err) + t.Cleanup(authServer1.Close) + + authServer2, err := jwks.NewServer(t) + require.NoError(t, err) + t.Cleanup(authServer2.Close) + token, err := authServer1.Token(map[string]any{"aud": tokenAudiences}) + require.NoError(t, err) + + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + URL: authServer2.JWKSURL(), + RefreshInterval: time.Second * 5, + Audiences: []string{"aud2"}, + }, + { + URL: authServer1.JWKSURL(), + RefreshInterval: time.Second * 5, + Audiences: []string{"aud1", "aud5"}, + }, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Since the order is random from a map we run multiple times to ensure + // that at least the correct order was hit + for range 10 { + func() { + // Operations with a token should succeed + header := http.Header{ + "Authorization": []string{"Bearer " + token}, + } + res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + require.Equal(t, JwksName, res.Header.Get(xAuthenticatedByHeader)) + data, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, employeesExpectedData, string(data)) + }() + } + }) }) t.Run("audience validation is ignored when expected aud is not provided", func(t *testing.T) { diff --git a/router/pkg/authentication/jwks_token_decoder.go b/router/pkg/authentication/jwks_token_decoder.go index f472577ea1..a99aa0b3b3 100644 --- a/router/pkg/authentication/jwks_token_decoder.go +++ b/router/pkg/authentication/jwks_token_decoder.go @@ -174,12 +174,6 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS keyFuncWrapper := jwt.Keyfunc(func(token *jwt.Token) (any, error) { var errJoin error for key, keyFunc := range keyFuncMap { - pub, err := keyFunc.Keyfunc(token) - if err != nil { - errJoin = errors.Join(errJoin, err) - continue - } - expectedAudiences := audiencesMap[key] if len(expectedAudiences) > 0 { tokenAudiences, err := token.Claims.GetAudience() @@ -192,6 +186,12 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS continue } } + + pub, err := keyFunc.Keyfunc(token) + if err != nil { + errJoin = errors.Join(errJoin, err) + continue + } return pub, nil } From 8d767c4429d4c43ffdd223ffada8be8a0ef7a5d0 Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Tue, 23 Sep 2025 18:46:20 +0530 Subject: [PATCH 38/45] fix: tests --- router-tests/authentication_test.go | 28 ++++++-------- .../pkg/authentication/jwks_token_decoder.go | 38 ++++++++++++------- 2 files changed, 36 insertions(+), 30 deletions(-) diff --git a/router-tests/authentication_test.go b/router-tests/authentication_test.go index db078a6f94..b04019cbc3 100644 --- a/router-tests/authentication_test.go +++ b/router-tests/authentication_test.go @@ -2827,24 +2827,18 @@ func TestAudienceValidation(t *testing.T) { core.WithAccessController(core.NewAccessController(authenticators, true)), }, }, func(t *testing.T, xEnv *testenv.Environment) { - // Since the order is random from a map we run multiple times to ensure - // that at least the correct order was hit - for range 10 { - func() { - // Operations with a token should succeed - header := http.Header{ - "Authorization": []string{"Bearer " + token}, - } - res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) - require.NoError(t, err) - defer res.Body.Close() - require.Equal(t, http.StatusOK, res.StatusCode) - require.Equal(t, JwksName, res.Header.Get(xAuthenticatedByHeader)) - data, err := io.ReadAll(res.Body) - require.NoError(t, err) - require.Equal(t, employeesExpectedData, string(data)) - }() + // Operations with a token should succeed + header := http.Header{ + "Authorization": []string{"Bearer " + token}, } + res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + require.Equal(t, JwksName, res.Header.Get(xAuthenticatedByHeader)) + data, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, employeesExpectedData, string(data)) }) }) diff --git a/router/pkg/authentication/jwks_token_decoder.go b/router/pkg/authentication/jwks_token_decoder.go index a99aa0b3b3..ec8cd0d6ab 100644 --- a/router/pkg/authentication/jwks_token_decoder.go +++ b/router/pkg/authentication/jwks_token_decoder.go @@ -4,10 +4,11 @@ import ( "context" "errors" "fmt" - "github.com/wundergraph/keyfunc/v3" "net/http" "time" + "github.com/wundergraph/keyfunc/v3" + "golang.org/x/time/rate" "github.com/MicahParks/jwkset" @@ -60,20 +61,26 @@ type RefreshUnknownKIDConfig struct { MaxWait time.Duration } -type audKey struct { +type configKey struct { kid string url string } type audienceSet map[string]struct{} +type keyFuncEntry struct { + jwks keyfunc.Keyfunc + aud audienceSet +} + func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKSConfig) (TokenDecoder, error) { - audiencesMap := make(map[audKey]audienceSet, len(configs)) - keyFuncMap := make(map[audKey]keyfunc.Keyfunc, len(configs)) + // Audience map is used to validate duplicate configs + audiencesMap := make(map[configKey]audienceSet, len(configs)) + entries := make([]keyFuncEntry, 0, len(configs)) for _, c := range configs { if c.URL != "" { - key := audKey{url: c.URL} + key := configKey{url: c.URL} if _, ok := audiencesMap[key]; ok { return nil, fmt.Errorf("duplicate JWK URL found: %s", c.URL) } @@ -117,10 +124,13 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS if err != nil { return nil, err } - keyFuncMap[key] = jwks + entries = append(entries, keyFuncEntry{ + jwks: jwks, + aud: audiencesMap[key], + }) } else if c.Secret != "" { - key := audKey{kid: c.KeyId} + key := configKey{kid: c.KeyId} if _, ok := audiencesMap[key]; ok { return nil, fmt.Errorf("duplicate JWK keyid specified found: %s", c.KeyId) } @@ -167,27 +177,29 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS if err != nil { return nil, err } - keyFuncMap[key] = jwks + entries = append(entries, keyFuncEntry{ + jwks: jwks, + aud: audiencesMap[key], + }) } } keyFuncWrapper := jwt.Keyfunc(func(token *jwt.Token) (any, error) { var errJoin error - for key, keyFunc := range keyFuncMap { - expectedAudiences := audiencesMap[key] - if len(expectedAudiences) > 0 { + for _, entry := range entries { + if len(entry.aud) > 0 { tokenAudiences, err := token.Claims.GetAudience() if err != nil { errJoin = errors.Join(errJoin, fmt.Errorf("could not get audiences from token claims: %w", err)) continue } - if !hasAudience(tokenAudiences, expectedAudiences) { + if !hasAudience(tokenAudiences, entry.aud) { errJoin = errors.Join(errJoin, errUnacceptableAud) continue } } - pub, err := keyFunc.Keyfunc(token) + pub, err := entry.jwks.Keyfunc(token) if err != nil { errJoin = errors.Join(errJoin, err) continue From 9cc810a3dd068b97b0a523f737e4ddad979fe2be Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Tue, 30 Sep 2025 15:33:35 +0530 Subject: [PATCH 39/45] fix: updates --- router-tests/go.mod | 2 +- router-tests/go.sum | 4 +- router/go.mod | 2 +- router/go.sum | 4 +- .../pkg/authentication/jwks_token_decoder.go | 47 ++++++++++++++----- 5 files changed, 41 insertions(+), 18 deletions(-) diff --git a/router-tests/go.mod b/router-tests/go.mod index 9fceb9f617..7b3061f68c 100644 --- a/router-tests/go.mod +++ b/router-tests/go.mod @@ -45,6 +45,7 @@ require ( connectrpc.com/connect v1.16.2 // indirect github.com/99designs/gqlgen v0.17.76 // indirect github.com/KimMachineGun/automemlimit v0.6.1 // indirect + github.com/MicahParks/keyfunc/v3 v3.3.5 // indirect github.com/agnivade/levenshtein v1.2.1 // indirect github.com/andybalholm/brotli v1.1.0 // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect @@ -150,7 +151,6 @@ require ( github.com/vbatts/tar-split v0.12.1 // indirect github.com/vektah/gqlparser/v2 v2.5.30 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect - github.com/wundergraph/keyfunc/v3 v3.0.0-20250922133930-92f21becf3d9 // indirect github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect diff --git a/router-tests/go.sum b/router-tests/go.sum index 0b8d19ce7f..8ba0cbab58 100644 --- a/router-tests/go.sum +++ b/router-tests/go.sum @@ -7,6 +7,8 @@ github.com/KimMachineGun/automemlimit v0.6.1 h1:ILa9j1onAAMadBsyyUJv5cack8Y1WT26 github.com/KimMachineGun/automemlimit v0.6.1/go.mod h1:T7xYht7B8r6AG/AqFcUdc7fzd2bIdBKmepfP2S1svPY= github.com/MicahParks/jwkset v0.11.0 h1:yc0zG+jCvZpWgFDFmvs8/8jqqVBG9oyIbmBtmjOhoyQ= github.com/MicahParks/jwkset v0.11.0/go.mod h1:U2oRhRaLgDCLjtpGL2GseNKGmZtLs/3O7p+OZaL5vo0= +github.com/MicahParks/keyfunc/v3 v3.3.5 h1:7ceAJLUAldnoueHDNzF8Bx06oVcQ5CfJnYwNt1U3YYo= +github.com/MicahParks/keyfunc/v3 v3.3.5/go.mod h1:SdCCyMJn/bYqWDvARspC6nCT8Sk74MjuAY22C7dCST8= github.com/PuerkitoBio/goquery v1.10.3 h1:pFYcNSqHxBD06Fpj/KsbStFRsgRATgnf3LeXiUkhzPo= github.com/PuerkitoBio/goquery v1.10.3/go.mod h1:tMUX0zDMHXYlAQk6p35XxQMqMweEKB7iK7iLNd4RH4Y= github.com/agnivade/levenshtein v1.2.1 h1:EHBY3UOn1gwdy/VbFwgo4cxecRznFk7fKWN1KOX7eoM= @@ -352,8 +354,6 @@ github.com/wundergraph/consul/sdk v0.0.0-20250204115147-ed842a8fd301 h1:EzfKHQoT github.com/wundergraph/consul/sdk v0.0.0-20250204115147-ed842a8fd301/go.mod h1:wxI0Nak5dI5RvJuzGyiEK4nZj0O9X+Aw6U0tC1wPKq0= github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.226 h1:3g6KNCG4ydgnpZnIlCK7pmtv0FSge6ILUS5LjrNZNiI= github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.226/go.mod h1:g1IFIylu5Fd9pKjzq0mDvpaKhEB/vkwLAIbGdX2djXU= -github.com/wundergraph/keyfunc/v3 v3.0.0-20250922133930-92f21becf3d9 h1:7bPpsPUUxy5dEnuDSy2q3PAmflxqKx9vnyaTj3TSMBo= -github.com/wundergraph/keyfunc/v3 v3.0.0-20250922133930-92f21becf3d9/go.mod h1:el0U1ewqJ/T/Urlt3wImfmuBmoQdjL5yoNQ5e/+O98M= github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 h1:gEOO8jv9F4OT7lGCjxCBTO/36wtF6j2nSip77qHd4x4= github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= diff --git a/router/go.mod b/router/go.mod index 2fd3bb1abf..1f4f7e5e55 100644 --- a/router/go.mod +++ b/router/go.mod @@ -59,6 +59,7 @@ require ( require ( github.com/KimMachineGun/automemlimit v0.6.1 github.com/MicahParks/jwkset v0.11.0 + github.com/MicahParks/keyfunc/v3 v3.3.5 github.com/alicebob/miniredis/v2 v2.34.0 github.com/caarlos0/env/v11 v11.3.1 github.com/cep21/circuit/v4 v4.0.0 @@ -78,7 +79,6 @@ require ( github.com/santhosh-tekuri/jsonschema/v6 v6.0.1 github.com/tonglil/opentelemetry-go-datadog-propagator v0.1.3 github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 - github.com/wundergraph/keyfunc/v3 v3.0.0-20250922133930-92f21becf3d9 go.uber.org/goleak v1.3.0 go.uber.org/ratelimit v0.3.1 golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 diff --git a/router/go.sum b/router/go.sum index 1d4cb2afbd..bea50f31ca 100644 --- a/router/go.sum +++ b/router/go.sum @@ -7,6 +7,8 @@ github.com/KimMachineGun/automemlimit v0.6.1 h1:ILa9j1onAAMadBsyyUJv5cack8Y1WT26 github.com/KimMachineGun/automemlimit v0.6.1/go.mod h1:T7xYht7B8r6AG/AqFcUdc7fzd2bIdBKmepfP2S1svPY= github.com/MicahParks/jwkset v0.11.0 h1:yc0zG+jCvZpWgFDFmvs8/8jqqVBG9oyIbmBtmjOhoyQ= github.com/MicahParks/jwkset v0.11.0/go.mod h1:U2oRhRaLgDCLjtpGL2GseNKGmZtLs/3O7p+OZaL5vo0= +github.com/MicahParks/keyfunc/v3 v3.3.5 h1:7ceAJLUAldnoueHDNzF8Bx06oVcQ5CfJnYwNt1U3YYo= +github.com/MicahParks/keyfunc/v3 v3.3.5/go.mod h1:SdCCyMJn/bYqWDvARspC6nCT8Sk74MjuAY22C7dCST8= github.com/agnivade/levenshtein v1.1.1 h1:QY8M92nrzkmr798gCo3kmMyqXFzdQVpxLlGPRBij0P8= github.com/agnivade/levenshtein v1.1.1/go.mod h1:veldBMzWxcCG2ZvUTKD2kJNRdCk5hVbJomOvKkmgYbo= github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 h1:uvdUDbHQHO85qeSydJtItA4T55Pw6BtAejd0APRJOCE= @@ -317,8 +319,6 @@ github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 h1:8/D7f8gKxTB github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083/go.mod h1:eOTL6acwctsN4F3b7YE+eE2t8zcJ/doLm9sZzsxxxrE= github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.226 h1:3g6KNCG4ydgnpZnIlCK7pmtv0FSge6ILUS5LjrNZNiI= github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.226/go.mod h1:g1IFIylu5Fd9pKjzq0mDvpaKhEB/vkwLAIbGdX2djXU= -github.com/wundergraph/keyfunc/v3 v3.0.0-20250922133930-92f21becf3d9 h1:7bPpsPUUxy5dEnuDSy2q3PAmflxqKx9vnyaTj3TSMBo= -github.com/wundergraph/keyfunc/v3 v3.0.0-20250922133930-92f21becf3d9/go.mod h1:el0U1ewqJ/T/Urlt3wImfmuBmoQdjL5yoNQ5e/+O98M= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= diff --git a/router/pkg/authentication/jwks_token_decoder.go b/router/pkg/authentication/jwks_token_decoder.go index ec8cd0d6ab..278d4f0b8e 100644 --- a/router/pkg/authentication/jwks_token_decoder.go +++ b/router/pkg/authentication/jwks_token_decoder.go @@ -5,9 +5,10 @@ import ( "errors" "fmt" "net/http" + "slices" "time" - "github.com/wundergraph/keyfunc/v3" + "github.com/MicahParks/keyfunc/v3" "golang.org/x/time/rate" @@ -69,8 +70,9 @@ type configKey struct { type audienceSet map[string]struct{} type keyFuncEntry struct { - jwks keyfunc.Keyfunc - aud audienceSet + jwks keyfunc.Keyfunc + aud audienceSet + allowedAlgorithms []string } func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKSConfig) (TokenDecoder, error) { @@ -120,13 +122,14 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS jwksetHTTPClientOptions.RateLimitWaitMax = c.RefreshUnknownKID.MaxWait } - jwks, err := createKeyFunc(ctx, jwksetHTTPClientOptions, c.AllowedAlgorithms) + jwks, err := createKeyFunc(ctx, jwksetHTTPClientOptions) if err != nil { return nil, err } entries = append(entries, keyFuncEntry{ - jwks: jwks, - aud: audiencesMap[key], + jwks: jwks, + aud: audiencesMap[key], + allowedAlgorithms: c.AllowedAlgorithms, }) } else if c.Secret != "" { @@ -173,7 +176,7 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS PrioritizeHTTP: false, } - jwks, err := createKeyFunc(ctx, jwksetHTTPClientOptions, make([]string, 0)) + jwks, err := createKeyFunc(ctx, jwksetHTTPClientOptions) if err != nil { return nil, err } @@ -199,6 +202,27 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS } } + // When an algorithm is actually provided in the jwks the current keyfunc will validate the + // jwks algorithm with it. But when no algorithm is provided (alg: none or missing alg) + // the default keyfunc will not validate the algorithm as it has nothing to cross check. + if len(entry.allowedAlgorithms) > 0 { + algInter, ok := token.Header["alg"] + if !ok { + return nil, fmt.Errorf("%w: could not find alg in JWT header", keyfunc.ErrKeyfunc) + } + alg, ok := algInter.(string) + if !ok { + // For test coverage purposes, this should be impossible to reach because the JWT package rejects a token + // without an alg parameter in the header before calling jwt.Keyfunc. + return nil, fmt.Errorf(`%w: the JWT header did not contain the "alg" parameter, which is required by RFC 7515 section 4.1.1`, keyfunc.ErrKeyfunc) + } + + // This is a custom validation different from the original keyfunc.Keyfunc + if !slices.Contains(entry.allowedAlgorithms, alg) { + return nil, fmt.Errorf("%w: could not find alg %s in allow list", keyfunc.ErrKeyfunc, alg) + } + } + pub, err := entry.jwks.Keyfunc(token) if err != nil { errJoin = errors.Join(errJoin, err) @@ -223,17 +247,16 @@ func getAudienceSet(audiences []string) audienceSet { return audSet } -func createKeyFunc(ctx context.Context, options jwkset.HTTPClientOptions, algorithms []string) (keyfunc.Keyfunc, error) { +func createKeyFunc(ctx context.Context, options jwkset.HTTPClientOptions) (keyfunc.Keyfunc, error) { combined, err := jwkset.NewHTTPClient(options) if err != nil { return nil, fmt.Errorf("failed to create HTTP client storage for JWK provider: %w", err) } keyfuncOptions := keyfunc.Options{ - Ctx: ctx, - Storage: combined, - UseWhitelist: []jwkset.USE{jwkset.UseSig}, - AllowedAlgorithms: algorithms, + Ctx: ctx, + Storage: combined, + UseWhitelist: []jwkset.USE{jwkset.UseSig}, } jwks, err := keyfunc.New(keyfuncOptions) From a2d6baebe0d3a3510d55c1069c5cf2a4b296b381 Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Wed, 1 Oct 2025 14:59:06 +0530 Subject: [PATCH 40/45] fix: comments --- router/pkg/authentication/jwks_token_decoder.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/router/pkg/authentication/jwks_token_decoder.go b/router/pkg/authentication/jwks_token_decoder.go index 25e63efa51..e876f4e097 100644 --- a/router/pkg/authentication/jwks_token_decoder.go +++ b/router/pkg/authentication/jwks_token_decoder.go @@ -213,8 +213,6 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS } alg, ok := algInter.(string) if !ok { - // For test coverage purposes, this should be impossible to reach because the JWT package rejects a token - // without an alg parameter in the header before calling jwt.Keyfunc. return nil, fmt.Errorf(`%w: the JWT header did not contain the "alg" parameter, which is required by RFC 7515 section 4.1.1`, keyfunc.ErrKeyfunc) } From 25bad88dc623cf125df13bd081b4154345610d84 Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Wed, 1 Oct 2025 15:17:10 +0530 Subject: [PATCH 41/45] fix: using continue --- router-tests/authentication_test.go | 51 +++++++++++++++++++ .../pkg/authentication/jwks_token_decoder.go | 9 ++-- 2 files changed, 57 insertions(+), 3 deletions(-) diff --git a/router-tests/authentication_test.go b/router-tests/authentication_test.go index b04019cbc3..8bc00e7d72 100644 --- a/router-tests/authentication_test.go +++ b/router-tests/authentication_test.go @@ -3001,6 +3001,57 @@ func TestAudienceValidation(t *testing.T) { require.JSONEq(t, unauthorizedExpectedData, string(data)) }) }) + + t.Run("valid token for second entry with empty algorithm in JWKS", func(t *testing.T) { + t.Parallel() + + rsaCrypto, err := jwks.NewRSACrypto("", "", 2048) + require.NoError(t, err) + + authServer1, err := jwks.NewServerWithCrypto(t, rsaCrypto) + require.NoError(t, err) + t.Cleanup(authServer1.Close) + + authServer2, err := jwks.NewServerWithCrypto(t, rsaCrypto) + require.NoError(t, err) + t.Cleanup(authServer1.Close) + + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + URL: authServer1.JWKSURL(), + RefreshInterval: time.Second * 5, + AllowedAlgorithms: []string{string(jwkset.AlgRS256)}, + }, + { + URL: authServer2.JWKSURL(), + RefreshInterval: time.Second * 5, + AllowedAlgorithms: []string{string(jwkset.AlgRS512)}, + }, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, false)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Operations with a token should succeed + token, err := authServer2.TokenWithOpts(nil, jwks.TokenOpts{ + AlgOverride: string(jwkset.AlgRS512), + }) + require.NoError(t, err) + header := http.Header{ + "Authorization": []string{"Bearer " + token}, + } + res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + require.Equal(t, JwksName, res.Header.Get(xAuthenticatedByHeader)) + data, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, employeesExpectedData, string(data)) + }) + }) } func toJWKSConfig(url string, refresh time.Duration, allowedAlgorithms ...string) authentication.JWKSConfig { diff --git a/router/pkg/authentication/jwks_token_decoder.go b/router/pkg/authentication/jwks_token_decoder.go index e876f4e097..015c588640 100644 --- a/router/pkg/authentication/jwks_token_decoder.go +++ b/router/pkg/authentication/jwks_token_decoder.go @@ -209,16 +209,19 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS if len(entry.allowedAlgorithms) > 0 { algInter, ok := token.Header["alg"] if !ok { - return nil, fmt.Errorf("%w: could not find alg in JWT header", keyfunc.ErrKeyfunc) + errJoin = errors.Join(errJoin, fmt.Errorf("%w: could not find alg in JWT header", keyfunc.ErrKeyfunc)) + continue } alg, ok := algInter.(string) if !ok { - return nil, fmt.Errorf(`%w: the JWT header did not contain the "alg" parameter, which is required by RFC 7515 section 4.1.1`, keyfunc.ErrKeyfunc) + errJoin = errors.Join(errJoin, fmt.Errorf(`%w: the JWT header did not contain the "alg" parameter, which is required by RFC 7515 section 4.1.1`, keyfunc.ErrKeyfunc)) + continue } // This is a custom validation different from the original keyfunc.Keyfunc if !slices.Contains(entry.allowedAlgorithms, alg) { - return nil, fmt.Errorf("%w: could not find alg %s in allow list", keyfunc.ErrKeyfunc, alg) + errJoin = errors.Join(errJoin, fmt.Errorf("%w: could not find alg %s in allow list", keyfunc.ErrKeyfunc, alg)) + continue } } From 0c3286fd3cde217303fc7b729130d47c1e45d10b Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Wed, 1 Oct 2025 15:26:23 +0530 Subject: [PATCH 42/45] fix: tests --- router-tests/authentication_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/router-tests/authentication_test.go b/router-tests/authentication_test.go index 8bc00e7d72..5dc0207524 100644 --- a/router-tests/authentication_test.go +++ b/router-tests/authentication_test.go @@ -3014,7 +3014,7 @@ func TestAudienceValidation(t *testing.T) { authServer2, err := jwks.NewServerWithCrypto(t, rsaCrypto) require.NoError(t, err) - t.Cleanup(authServer1.Close) + t.Cleanup(authServer2.Close) authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ { From 113a61f464e983ded8cb9580e871ee54060292c4 Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Wed, 1 Oct 2025 15:44:01 +0530 Subject: [PATCH 43/45] fix: tests --- router-tests/authentication_test.go | 124 +++++++++++++++++++--------- 1 file changed, 87 insertions(+), 37 deletions(-) diff --git a/router-tests/authentication_test.go b/router-tests/authentication_test.go index 5dc0207524..68a44e55c0 100644 --- a/router-tests/authentication_test.go +++ b/router-tests/authentication_test.go @@ -2796,49 +2796,99 @@ func TestAudienceValidation(t *testing.T) { t.Run("audience validation succeeds even when one audience match fails", func(t *testing.T) { t.Parallel() - tokenAudiences := []string{"aud1"} + t.Run("with http based configuration", func(t *testing.T) { + t.Parallel() - authServer1, err := jwks.NewServer(t) - require.NoError(t, err) - t.Cleanup(authServer1.Close) + tokenAudiences := []string{"aud1"} - authServer2, err := jwks.NewServer(t) - require.NoError(t, err) - t.Cleanup(authServer2.Close) + authServer1, err := jwks.NewServer(t) + require.NoError(t, err) + t.Cleanup(authServer1.Close) - token, err := authServer1.Token(map[string]any{"aud": tokenAudiences}) - require.NoError(t, err) + authServer2, err := jwks.NewServer(t) + require.NoError(t, err) + t.Cleanup(authServer2.Close) - authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ - { - URL: authServer2.JWKSURL(), - RefreshInterval: time.Second * 5, - Audiences: []string{"aud2"}, - }, - { - URL: authServer1.JWKSURL(), - RefreshInterval: time.Second * 5, - Audiences: []string{"aud1", "aud5"}, - }, + token, err := authServer1.Token(map[string]any{"aud": tokenAudiences}) + require.NoError(t, err) + + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + URL: authServer2.JWKSURL(), + RefreshInterval: time.Second * 5, + Audiences: []string{"aud2"}, + }, + { + URL: authServer1.JWKSURL(), + RefreshInterval: time.Second * 5, + Audiences: []string{"aud1", "aud5"}, + }, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Operations with a token should succeed + header := http.Header{ + "Authorization": []string{"Bearer " + token}, + } + res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + require.Equal(t, JwksName, res.Header.Get(xAuthenticatedByHeader)) + data, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, employeesExpectedData, string(data)) + }) }) - testenv.Run(t, &testenv.Config{ - RouterOptions: []core.Option{ - core.WithAccessController(core.NewAccessController(authenticators, true)), - }, - }, func(t *testing.T, xEnv *testenv.Environment) { - // Operations with a token should succeed - header := http.Header{ - "Authorization": []string{"Bearer " + token}, - } - res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) - require.NoError(t, err) - defer res.Body.Close() - require.Equal(t, http.StatusOK, res.StatusCode) - require.Equal(t, JwksName, res.Header.Get(xAuthenticatedByHeader)) - data, err := io.ReadAll(res.Body) - require.NoError(t, err) - require.Equal(t, employeesExpectedData, string(data)) + t.Run("with secret based configuration", func(t *testing.T) { + t.Parallel() + + matchingAud := "matchingAudience" + + secret := "example secret" + kid := "givenKID" + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + Secret: "secret", + Algorithm: string(jwkset.AlgHS256), + KeyId: "kid", + Audiences: []string{"aud3"}, + }, + { + Secret: secret, + Algorithm: string(jwkset.AlgHS256), + KeyId: kid, + Audiences: []string{matchingAud, "aud5"}, + }, + }) + + token := generateToken(t, kid, secret, jwt.SigningMethodHS256, jwt.MapClaims{ + "aud": matchingAud, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Operations with a token should succeed + header := http.Header{ + "Authorization": []string{"Bearer " + token}, + } + res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + require.Equal(t, JwksName, res.Header.Get(xAuthenticatedByHeader)) + data, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, employeesExpectedData, string(data)) + }) }) }) From 59323d99f18fd963f0f347ec87cd04412f8c5773 Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Wed, 1 Oct 2025 15:52:01 +0530 Subject: [PATCH 44/45] fix: go.mod updates --- router-tests/go.mod | 2 +- router-tests/go.sum | 4 ++-- router/go.mod | 2 +- router/go.sum | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/router-tests/go.mod b/router-tests/go.mod index a84ab7eb81..b49dec69d0 100644 --- a/router-tests/go.mod +++ b/router-tests/go.mod @@ -45,7 +45,7 @@ require ( connectrpc.com/connect v1.16.2 // indirect github.com/99designs/gqlgen v0.17.76 // indirect github.com/KimMachineGun/automemlimit v0.6.1 // indirect - github.com/MicahParks/keyfunc/v3 v3.3.5 // indirect + github.com/MicahParks/keyfunc/v3 v3.6.2 // indirect github.com/agnivade/levenshtein v1.2.1 // indirect github.com/andybalholm/brotli v1.1.0 // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect diff --git a/router-tests/go.sum b/router-tests/go.sum index a33cb64bbb..65eaff9693 100644 --- a/router-tests/go.sum +++ b/router-tests/go.sum @@ -7,8 +7,8 @@ github.com/KimMachineGun/automemlimit v0.6.1 h1:ILa9j1onAAMadBsyyUJv5cack8Y1WT26 github.com/KimMachineGun/automemlimit v0.6.1/go.mod h1:T7xYht7B8r6AG/AqFcUdc7fzd2bIdBKmepfP2S1svPY= github.com/MicahParks/jwkset v0.11.0 h1:yc0zG+jCvZpWgFDFmvs8/8jqqVBG9oyIbmBtmjOhoyQ= github.com/MicahParks/jwkset v0.11.0/go.mod h1:U2oRhRaLgDCLjtpGL2GseNKGmZtLs/3O7p+OZaL5vo0= -github.com/MicahParks/keyfunc/v3 v3.3.5 h1:7ceAJLUAldnoueHDNzF8Bx06oVcQ5CfJnYwNt1U3YYo= -github.com/MicahParks/keyfunc/v3 v3.3.5/go.mod h1:SdCCyMJn/bYqWDvARspC6nCT8Sk74MjuAY22C7dCST8= +github.com/MicahParks/keyfunc/v3 v3.6.2 h1:82rre60MKw4r117ew5/T4m1AphgkpCOYry0RPbFUY3w= +github.com/MicahParks/keyfunc/v3 v3.6.2/go.mod h1:z66bkCviwqfg2YUp+Jcc/xRE9IXLcMq6DrgV/+Htru0= github.com/PuerkitoBio/goquery v1.10.3 h1:pFYcNSqHxBD06Fpj/KsbStFRsgRATgnf3LeXiUkhzPo= github.com/PuerkitoBio/goquery v1.10.3/go.mod h1:tMUX0zDMHXYlAQk6p35XxQMqMweEKB7iK7iLNd4RH4Y= github.com/agnivade/levenshtein v1.2.1 h1:EHBY3UOn1gwdy/VbFwgo4cxecRznFk7fKWN1KOX7eoM= diff --git a/router/go.mod b/router/go.mod index 0497c14b7b..3bf34dfab4 100644 --- a/router/go.mod +++ b/router/go.mod @@ -59,7 +59,7 @@ require ( require ( github.com/KimMachineGun/automemlimit v0.6.1 github.com/MicahParks/jwkset v0.11.0 - github.com/MicahParks/keyfunc/v3 v3.3.5 + github.com/MicahParks/keyfunc/v3 v3.6.2 github.com/alicebob/miniredis/v2 v2.34.0 github.com/caarlos0/env/v11 v11.3.1 github.com/cep21/circuit/v4 v4.0.0 diff --git a/router/go.sum b/router/go.sum index dad2da1a19..0bcf0aa4f0 100644 --- a/router/go.sum +++ b/router/go.sum @@ -7,8 +7,8 @@ github.com/KimMachineGun/automemlimit v0.6.1 h1:ILa9j1onAAMadBsyyUJv5cack8Y1WT26 github.com/KimMachineGun/automemlimit v0.6.1/go.mod h1:T7xYht7B8r6AG/AqFcUdc7fzd2bIdBKmepfP2S1svPY= github.com/MicahParks/jwkset v0.11.0 h1:yc0zG+jCvZpWgFDFmvs8/8jqqVBG9oyIbmBtmjOhoyQ= github.com/MicahParks/jwkset v0.11.0/go.mod h1:U2oRhRaLgDCLjtpGL2GseNKGmZtLs/3O7p+OZaL5vo0= -github.com/MicahParks/keyfunc/v3 v3.3.5 h1:7ceAJLUAldnoueHDNzF8Bx06oVcQ5CfJnYwNt1U3YYo= -github.com/MicahParks/keyfunc/v3 v3.3.5/go.mod h1:SdCCyMJn/bYqWDvARspC6nCT8Sk74MjuAY22C7dCST8= +github.com/MicahParks/keyfunc/v3 v3.6.2 h1:82rre60MKw4r117ew5/T4m1AphgkpCOYry0RPbFUY3w= +github.com/MicahParks/keyfunc/v3 v3.6.2/go.mod h1:z66bkCviwqfg2YUp+Jcc/xRE9IXLcMq6DrgV/+Htru0= github.com/agnivade/levenshtein v1.1.1 h1:QY8M92nrzkmr798gCo3kmMyqXFzdQVpxLlGPRBij0P8= github.com/agnivade/levenshtein v1.1.1/go.mod h1:veldBMzWxcCG2ZvUTKD2kJNRdCk5hVbJomOvKkmgYbo= github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 h1:uvdUDbHQHO85qeSydJtItA4T55Pw6BtAejd0APRJOCE= From 54ad4f324b16142258f8d433ffb84b6c87e66db2 Mon Sep 17 00:00:00 2001 From: Milinda Dias Date: Wed, 1 Oct 2025 15:54:58 +0530 Subject: [PATCH 45/45] fix: imports --- router/pkg/authentication/jwks_token_decoder.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/router/pkg/authentication/jwks_token_decoder.go b/router/pkg/authentication/jwks_token_decoder.go index 015c588640..1685c2285a 100644 --- a/router/pkg/authentication/jwks_token_decoder.go +++ b/router/pkg/authentication/jwks_token_decoder.go @@ -8,11 +8,10 @@ import ( "slices" "time" - "github.com/MicahParks/keyfunc/v3" - "golang.org/x/time/rate" "github.com/MicahParks/jwkset" + "github.com/MicahParks/keyfunc/v3" "github.com/golang-jwt/jwt/v5" "github.com/wundergraph/cosmo/router/internal/httpclient"