diff --git a/router-tests/authentication_test.go b/router-tests/authentication_test.go index 38d2715a56..326b84749a 100644 --- a/router-tests/authentication_test.go +++ b/router-tests/authentication_test.go @@ -8,6 +8,8 @@ import ( "io" "net/http" "strings" + "sync" + "sync/atomic" "testing" "time" @@ -56,6 +58,333 @@ func TestAuthentication(t *testing.T) { }) }) + t.Run("unknown kid refresh blocks when burst exceeded", func(t *testing.T) { + t.Parallel() + + authServer, err := jwks.NewServer(t) + require.NoError(t, err) + t.Cleanup(authServer.Close) + + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + URL: authServer.JWKSURL(), + RefreshInterval: 10 * time.Second, + RefreshUnknownKID: authentication.RefreshUnknownKIDConfig{ + Enabled: true, + Interval: 1 * time.Second, + Burst: 1, + }, + }, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + token, err := authServer.TokenForKID("unknown_kid", nil, true) + require.NoError(t, err) + + header := http.Header{"Authorization": []string{"Bearer " + token}} + + res1, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res1.Body.Close() }() + require.Equal(t, http.StatusUnauthorized, res1.StatusCode) + _, err = io.ReadAll(res1.Body) + require.NoError(t, err) + + start := time.Now() + res2, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res2.Body.Close() }() + elapsed := time.Since(start) + + require.True(t, elapsed >= 700*time.Millisecond) + require.Equal(t, http.StatusUnauthorized, res2.StatusCode) + data, err := io.ReadAll(res2.Body) + require.NoError(t, err) + require.JSONEq(t, unauthorizedExpectedData, string(data)) + }) + }) + + t.Run("unknown kid refresh does not block when burst not exceeded", func(t *testing.T) { + t.Parallel() + + authServer, err := jwks.NewServer(t) + require.NoError(t, err) + t.Cleanup(authServer.Close) + + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + URL: authServer.JWKSURL(), + RefreshInterval: 10 * time.Second, + RefreshUnknownKID: authentication.RefreshUnknownKIDConfig{ + Enabled: true, + Interval: 1 * time.Second, + Burst: 1, + }, + }, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + token, err := authServer.TokenForKID("unknown_kid", nil, true) + require.NoError(t, err) + header := http.Header{"Authorization": []string{"Bearer " + token}} + + res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res.Body.Close() }() + require.Equal(t, http.StatusUnauthorized, res.StatusCode) + _, err = io.ReadAll(res.Body) + require.NoError(t, err) + + // Wait for interval so next refresh is within burst budget + time.Sleep(1200 * time.Millisecond) + + start := time.Now() + res2, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res2.Body.Close() }() + elapsed := time.Since(start) + require.True(t, elapsed < 100*time.Millisecond) + require.Equal(t, http.StatusUnauthorized, res2.StatusCode) + data, err := io.ReadAll(res2.Body) + require.NoError(t, err) + require.JSONEq(t, unauthorizedExpectedData, string(data)) + }) + }) + + // Since the rate limiter knows that the limit will definitely be exceeded it exits + // immediately without waiting + t.Run("unknown kid refresh interval exceeding max wait returns immediately", func(t *testing.T) { + t.Parallel() + + authServer, err := jwks.NewServer(t) + require.NoError(t, err) + t.Cleanup(authServer.Close) + + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + URL: authServer.JWKSURL(), + RefreshInterval: 10 * time.Second, + RefreshUnknownKID: authentication.RefreshUnknownKIDConfig{ + Enabled: true, + Interval: 1 * time.Second, // next token available in ~1s + Burst: 1, + MaxWait: 700 * time.Millisecond, // cap wait well below interval + }, + }, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + token, err := authServer.TokenForKID("unknown_kid", nil, true) + require.NoError(t, err) + + header := http.Header{"Authorization": []string{"Bearer " + token}} + + res1, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res1.Body.Close() }() + require.Equal(t, http.StatusUnauthorized, res1.StatusCode) + _, err = io.ReadAll(res1.Body) + require.NoError(t, err) + + // Next call should exceed max wait so should return immediately + start := time.Now() + res2, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res2.Body.Close() }() + elapsed := time.Since(start) + require.True(t, elapsed < 100*time.Millisecond) + require.Equal(t, http.StatusUnauthorized, res2.StatusCode) + data, err := io.ReadAll(res2.Body) + require.NoError(t, err) + require.JSONEq(t, unauthorizedExpectedData, string(data)) + }) + }) + + t.Run("unknown kid refresh exceeding burst waits until interval when max wait larger", func(t *testing.T) { + t.Parallel() + + authServer, err := jwks.NewServer(t) + require.NoError(t, err) + t.Cleanup(authServer.Close) + + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + URL: authServer.JWKSURL(), + RefreshInterval: 10 * time.Second, + RefreshUnknownKID: authentication.RefreshUnknownKIDConfig{ + Enabled: true, + Interval: 1 * time.Second, + Burst: 1, + MaxWait: 2 * time.Second, // larger than interval, so it can wait until next token + }, + }, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + token, err := authServer.TokenForKID("unknown_kid", nil, true) + require.NoError(t, err) + + header := http.Header{"Authorization": []string{"Bearer " + token}} + + res1, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res1.Body.Close() }() + require.Equal(t, http.StatusUnauthorized, res1.StatusCode) + _, err = io.ReadAll(res1.Body) + require.NoError(t, err) + + start := time.Now() + res2, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res2.Body.Close() }() + elapsed := time.Since(start) + + require.True(t, elapsed >= 600*time.Millisecond) + require.Equal(t, http.StatusUnauthorized, res2.StatusCode) + data, err := io.ReadAll(res2.Body) + require.NoError(t, err) + require.JSONEq(t, unauthorizedExpectedData, string(data)) + }) + }) + + // After consuming the single burst token, launch multiple requests in parallel. + // Each should block if the max limit has not been accumulated + t.Run("unknown kid refresh parallel exceeding burst waits up to max wait", func(t *testing.T) { + t.Parallel() + + authServer, err := jwks.NewServer(t) + require.NoError(t, err) + t.Cleanup(authServer.Close) + + const waitEntries = 4 + + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + URL: authServer.JWKSURL(), + RefreshInterval: 10 * time.Second, + RefreshUnknownKID: authentication.RefreshUnknownKIDConfig{ + Enabled: true, + Interval: 1 * time.Second, + Burst: 1, + MaxWait: waitEntries * time.Second, + }, + }, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + token, err := authServer.TokenForKID("unknown_kid", nil, true) + require.NoError(t, err) + + header := http.Header{"Authorization": []string{"Bearer " + token}} + + // Send initial request to use up the burst token + res1, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res1.Body.Close() }() + require.Equal(t, http.StatusUnauthorized, res1.StatusCode) + _, err = io.ReadAll(res1.Body) + require.NoError(t, err) + + var elapsedFastCounter atomic.Int64 + var wg sync.WaitGroup + + for range waitEntries + 1 { + wg.Add(1) + + go func() { + defer wg.Done() + + start := time.Now() + res2, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res2.Body.Close() }() + + elapsed := time.Since(start) + + if elapsed < 100*time.Millisecond { + elapsedFastCounter.Add(1) + } + + require.True(t, elapsed < 50*time.Millisecond || elapsed >= 700*time.Millisecond) + require.Equal(t, http.StatusUnauthorized, res2.StatusCode) + data, err := io.ReadAll(res2.Body) + require.NoError(t, err) + require.JSONEq(t, unauthorizedExpectedData, string(data)) + }() + } + + wg.Wait() + + // We only exit early on the 5th request as by the 5th request we have accumulated + // enough tokens to exceed the max wait duration + require.Equal(t, 1, int(elapsedFastCounter.Load())) + }) + }) + + t.Run("authentication should not block with unknown kid when refresh is disabled", func(t *testing.T) { + t.Parallel() + + authServer, err := jwks.NewServer(t) + require.NoError(t, err) + t.Cleanup(authServer.Close) + + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + URL: authServer.JWKSURL(), + RefreshInterval: 100 * time.Millisecond, + }, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Create a token signed with a valid key but with an unknown kid header + token, err := authServer.TokenForKID("unknown_kid", nil, true) + require.NoError(t, err) + + maxDuration := 4 * time.Second + testenv.AwaitFunc(t, maxDuration, func() { + for range 5 { + func() { + header := http.Header{ + "Authorization": []string{"Bearer " + token}, + } + res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res.Body.Close() }() + require.Equal(t, http.StatusUnauthorized, res.StatusCode) + require.Equal(t, "", res.Header.Get(xAuthenticatedByHeader)) + data, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.JSONEq(t, unauthorizedExpectedData, string(data)) + }() + } + }) + }) + }) + t.Run("invalid token", func(t *testing.T) { t.Parallel() @@ -1182,7 +1511,7 @@ func TestAlgorithmMismatch(t *testing.T) { authenticators := []authentication.Authenticator{authenticator} - token, err := authServer.TokenForKID(crypto.KID(), nil) + token, err := authServer.TokenForKID(crypto.KID(), nil, false) require.NoError(t, err) return token, authenticators @@ -1307,7 +1636,7 @@ func TestOidcDiscovery(t *testing.T) { tokens := make(map[string]string) for _, c := range crypto { - token, err := authServer.TokenForKID(c.KID(), nil) + token, err := authServer.TokenForKID(c.KID(), nil, false) require.NoError(t, err) tokens[c.KID()] = token @@ -1421,7 +1750,7 @@ func TestMultipleKeys(t *testing.T) { tokens := make(map[string]string) for _, c := range crypto { - token, err := authServer.TokenForKID(c.KID(), nil) + token, err := authServer.TokenForKID(c.KID(), nil, false) require.NoError(t, err) tokens[c.KID()] = token @@ -1604,7 +1933,7 @@ func TestSupportedAlgorithms(t *testing.T) { authenticators := []authentication.Authenticator{authenticator} - token, err := authServer.TokenForKID(crypto.KID(), nil) + token, err := authServer.TokenForKID(crypto.KID(), nil, false) require.NoError(t, err) return token, authenticators diff --git a/router-tests/header_set_test.go b/router-tests/header_set_test.go index 763663e2b5..a8f15d69af 100644 --- a/router-tests/header_set_test.go +++ b/router-tests/header_set_test.go @@ -275,7 +275,7 @@ func TestHeaderSetWithExpression(t *testing.T) { authenticator, err := authentication.NewHttpHeaderAuthenticator(authOptions) require.NoError(t, err) - token, err := authServer.TokenForKID(rsa1.KID(), map[string]any{"user_id": "TestId"}) + token, err := authServer.TokenForKID(rsa1.KID(), map[string]any{"user_id": "TestId"}, false) require.NoError(t, err) testenv.Run(t, &testenv.Config{ diff --git a/router-tests/jwks/jwks.go b/router-tests/jwks/jwks.go index 080879aa0c..6b77a76812 100644 --- a/router-tests/jwks/jwks.go +++ b/router-tests/jwks/jwks.go @@ -50,11 +50,18 @@ func (s *Server) Token(claims map[string]any) (string, error) { return "", jwt.ErrInvalidKey } -func (s *Server) TokenForKID(kid string, claims map[string]any) (string, error) { +func (s *Server) TokenForKID(kid string, claims map[string]any, useInvalidKID bool) (string, error) { provider, ok := s.providers[kid] - if !ok { + if useInvalidKID { + // If we don't care about the kid, use any available provider + for _, pr := range s.providers { + provider = pr + break + } + } else if !ok { return "", jwt.ErrInvalidKey } + token := jwt.NewWithClaims(provider.SigningMethod(), jwt.MapClaims(claims)) token.Header[jwkset.HeaderKID] = kid return token.SignedString(provider.PrivateKey()) diff --git a/router-tests/testenv/utils.go b/router-tests/testenv/utils.go index bd4f1842db..19d13ba27c 100644 --- a/router-tests/testenv/utils.go +++ b/router-tests/testenv/utils.go @@ -28,3 +28,15 @@ func AwaitChannelWithCloseWithT[A any](t *testing.T, timeout time.Duration, ch < require.Fail(t, "unable to receive message before timeout", msgAndArgs...) } } + +func AwaitFunc(t *testing.T, timeout time.Duration, testFunction func()) { + t.Helper() + + doneCh := make(chan struct{}) + go func() { + defer close(doneCh) + testFunction() + }() + + AwaitChannelWithT(t, timeout, doneCh, func(t *testing.T, _ struct{}) {}, "the test function timed out") +} diff --git a/router/core/supervisor_instance.go b/router/core/supervisor_instance.go index 09605f6d9e..7a43d3f138 100644 --- a/router/core/supervisor_instance.go +++ b/router/core/supervisor_instance.go @@ -265,6 +265,12 @@ func setupAuthenticators(ctx context.Context, logger *zap.Logger, cfg *config.Co KeyId: jwks.KeyId, Audiences: jwks.Audiences, + RefreshUnknownKID: authentication.RefreshUnknownKIDConfig{ + Enabled: jwks.RefreshUnknownKID.Enabled, + MaxWait: jwks.RefreshUnknownKID.MaxWait, + Interval: jwks.RefreshUnknownKID.Interval, + Burst: jwks.RefreshUnknownKID.Burst, + }, }) } diff --git a/router/pkg/authentication/jwks_token_decoder.go b/router/pkg/authentication/jwks_token_decoder.go index 53b252632c..c6ae6c794e 100644 --- a/router/pkg/authentication/jwks_token_decoder.go +++ b/router/pkg/authentication/jwks_token_decoder.go @@ -7,13 +7,13 @@ import ( "net/http" "time" + "golang.org/x/time/rate" + "github.com/MicahParks/jwkset" "github.com/MicahParks/keyfunc/v3" "github.com/golang-jwt/jwt/v5" - "go.uber.org/zap" - "golang.org/x/time/rate" - "github.com/wundergraph/cosmo/router/internal/httpclient" + "go.uber.org/zap" ) type TokenDecoder interface { @@ -49,6 +49,15 @@ type JWKSConfig struct { KeyId string Audiences []string + + RefreshUnknownKID RefreshUnknownKIDConfig +} + +type RefreshUnknownKIDConfig struct { + Enabled bool + Interval time.Duration + Burst int + MaxWait time.Duration } type audKey struct { @@ -95,8 +104,13 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS HTTPURLs: map[string]jwkset.Storage{ c.URL: store, }, - PrioritizeHTTP: true, - RefreshUnknownKID: rate.NewLimiter(rate.Every(5*time.Minute), 1), + PrioritizeHTTP: true, + } + + // Configure the rate limiter for refreshing unknown KIDs + if c.RefreshUnknownKID.Enabled { + jwksetHTTPClientOptions.RefreshUnknownKID = rate.NewLimiter(rate.Every(c.RefreshUnknownKID.Interval), c.RefreshUnknownKID.Burst) + jwksetHTTPClientOptions.RateLimitWaitMax = c.RefreshUnknownKID.MaxWait } jwks, err := createKeyFunc(ctx, jwksetHTTPClientOptions) diff --git a/router/pkg/config/config.go b/router/pkg/config/config.go index c33310d657..57591de31e 100644 --- a/router/pkg/config/config.go +++ b/router/pkg/config/config.go @@ -466,9 +466,10 @@ type OverridesConfiguration struct { } type JWKSConfiguration struct { - URL string `yaml:"url"` - Algorithms []string `yaml:"algorithms"` - RefreshInterval time.Duration `yaml:"refresh_interval" envDefault:"1m"` + URL string `yaml:"url"` + Algorithms []string `yaml:"algorithms"` + RefreshInterval time.Duration `yaml:"refresh_interval" envDefault:"1m"` + RefreshUnknownKID RefreshUnknownKID `yaml:"refresh_unknown_kid"` // For secret based where we need to create a jwk entry with // a key id and algorithm @@ -480,6 +481,13 @@ type JWKSConfiguration struct { Audiences []string `yaml:"audiences"` } +type RefreshUnknownKID struct { + Enabled bool `yaml:"enabled" envDefault:"false"` + MaxWait time.Duration `yaml:"max_wait" envDefault:"2m"` + Interval time.Duration `yaml:"interval" envDefault:"30s"` + Burst int `yaml:"burst" envDefault:"2"` +} + type HeaderSource struct { Type string `yaml:"type"` Name string `yaml:"name"` diff --git a/router/pkg/config/config.schema.json b/router/pkg/config/config.schema.json index c45890172e..570d5f8086 100644 --- a/router/pkg/config/config.schema.json +++ b/router/pkg/config/config.schema.json @@ -1716,6 +1716,40 @@ }, "description": "The interval at which the JWKs are refreshed. The period is specified as a string with a number and a unit, e.g. 10ms, 1s, 1m, 1h. The supported units are 'ms', 's', 'm', 'h'.", "default": "1m" + }, + "refresh_unknown_kid": { + "type": "object", + "description": "Controls rate-limited refresh behavior when a JWT KID is unknown.", + "additionalProperties": false, + "properties": { + "enabled": { + "type": "boolean", + "description": "Enable refresh attempts on unknown KID.", + "default": false + }, + "max_wait": { + "type": "string", + "description": "Maximum time to wait for a refresh permit before giving up.", + "default": "10s", + "duration": { + "minimum": "0s" + } + }, + "interval": { + "type": "string", + "description": "Token refill interval for the rate limiter.", + "default": "1m", + "duration": { + "minimum": "1s" + } + }, + "burst": { + "type": "integer", + "description": "Burst size for the rate limiter.", + "default": 2, + "minimum": 1 + } + } } }, "oneOf": [ diff --git a/router/pkg/config/fixtures/full.yaml b/router/pkg/config/fixtures/full.yaml index afd160f911..a43691cc12 100644 --- a/router/pkg/config/fixtures/full.yaml +++ b/router/pkg/config/fixtures/full.yaml @@ -275,6 +275,11 @@ authentication: - url: 'https://example.com/.well-known/jwks2.json' refresh_interval: 2m algorithms: ['RS256', 'ES256'] + refresh_unknown_kid: + enabled: true + max_wait: 10s + interval: 5s + burst: 3 - url: 'https://example.com/.well-known/jwks3.json' header_name: Authorization header_value_prefix: Bearer diff --git a/router/pkg/config/testdata/config_full.json b/router/pkg/config/testdata/config_full.json index ec8b47da50..ba295c3935 100644 --- a/router/pkg/config/testdata/config_full.json +++ b/router/pkg/config/testdata/config_full.json @@ -480,6 +480,12 @@ "RS256" ], "RefreshInterval": 60000000000, + "RefreshUnknownKID": { + "Enabled": false, + "MaxWait": 0, + "Interval": 0, + "Burst": 0 + }, "Secret": "", "Algorithm": "", "KeyId": "", @@ -492,6 +498,12 @@ "ES256" ], "RefreshInterval": 120000000000, + "RefreshUnknownKID": { + "Enabled": true, + "MaxWait": 10000000000, + "Interval": 5000000000, + "Burst": 3 + }, "Secret": "", "Algorithm": "", "KeyId": "", @@ -501,6 +513,12 @@ "URL": "https://example.com/.well-known/jwks3.json", "Algorithms": null, "RefreshInterval": 0, + "RefreshUnknownKID": { + "Enabled": false, + "MaxWait": 0, + "Interval": 0, + "Burst": 0 + }, "Secret": "", "Algorithm": "", "KeyId": "",