Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 149 additions & 4 deletions router-tests/authentication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,151 @@ func TestAuthentication(t *testing.T) {
})
})

t.Run("unknown kid refresh blocks when burst exceeded", func(t *testing.T) {
t.Parallel()

authServer, err := jwks.NewServer(t)
require.NoError(t, err)
t.Cleanup(authServer.Close)

authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{
{
URL: authServer.JWKSURL(),
RefreshInterval: 10 * time.Second,
RefreshUnknownKID: authentication.RefreshUnknownKIDConfig{
Enabled: true,
Interval: 1 * time.Second,
Burst: 1,
},
},
})

testenv.Run(t, &testenv.Config{
RouterOptions: []core.Option{
core.WithAccessController(core.NewAccessController(authenticators, true)),
},
}, func(t *testing.T, xEnv *testenv.Environment) {
token, err := authServer.TokenForKID("unknown_kid", nil, true)
require.NoError(t, err)

header := http.Header{"Authorization": []string{"Bearer " + token}}

res1, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery))
require.NoError(t, err)
defer func() { _ = res1.Body.Close() }()
require.Equal(t, http.StatusUnauthorized, res1.StatusCode)
_, err = io.ReadAll(res1.Body)
require.NoError(t, err)

start := time.Now()
res2, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery))
require.NoError(t, err)
defer func() { _ = res2.Body.Close() }()
elapsed := time.Since(start)

require.True(t, elapsed >= 600*time.Millisecond)
require.Equal(t, http.StatusUnauthorized, res2.StatusCode)
data, err := io.ReadAll(res2.Body)
require.NoError(t, err)
require.JSONEq(t, unauthorizedExpectedData, string(data))
})
})

t.Run("unknown kid refresh does not block when burst not exceeded", func(t *testing.T) {
t.Parallel()

authServer, err := jwks.NewServer(t)
require.NoError(t, err)
t.Cleanup(authServer.Close)

authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{
{
URL: authServer.JWKSURL(),
RefreshInterval: 10 * time.Second,
RefreshUnknownKID: authentication.RefreshUnknownKIDConfig{
Enabled: true,
Interval: 1 * time.Second,
Burst: 1,
},
},
})

testenv.Run(t, &testenv.Config{
RouterOptions: []core.Option{
core.WithAccessController(core.NewAccessController(authenticators, true)),
},
}, func(t *testing.T, xEnv *testenv.Environment) {
token, err := authServer.TokenForKID("unknown_kid", nil, true)
require.NoError(t, err)
header := http.Header{"Authorization": []string{"Bearer " + token}}

res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery))
require.NoError(t, err)
defer func() { _ = res.Body.Close() }()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
_, err = io.ReadAll(res.Body)
require.NoError(t, err)

// Wait for interval so next refresh is within burst budget
time.Sleep(1200 * time.Millisecond)

start := time.Now()
res2, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery))
require.NoError(t, err)
defer func() { _ = res2.Body.Close() }()
elapsed := time.Since(start)
require.True(t, elapsed < 100*time.Millisecond)
require.Equal(t, http.StatusUnauthorized, res2.StatusCode)
data, err := io.ReadAll(res2.Body)
require.NoError(t, err)
require.JSONEq(t, unauthorizedExpectedData, string(data))
})
})

t.Run("authentication should not block with unknown kid when refresh is disabled", func(t *testing.T) {
t.Parallel()

authServer, err := jwks.NewServer(t)
require.NoError(t, err)
t.Cleanup(authServer.Close)

authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{
{
URL: authServer.JWKSURL(),
RefreshInterval: 100 * time.Millisecond,
},
})

testenv.Run(t, &testenv.Config{
RouterOptions: []core.Option{
core.WithAccessController(core.NewAccessController(authenticators, true)),
},
}, func(t *testing.T, xEnv *testenv.Environment) {
// Create a token signed with a valid key but with an unknown kid header
token, err := authServer.TokenForKID("unknown_kid", nil, true)
require.NoError(t, err)

maxDuration := 4 * time.Second
testenv.AwaitFunc(t, maxDuration, func() {
for range 5 {
func() {
header := http.Header{
"Authorization": []string{"Bearer " + token},
}
res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery))
require.NoError(t, err)
defer func() { _ = res.Body.Close() }()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
require.Equal(t, "", res.Header.Get(xAuthenticatedByHeader))
data, err := io.ReadAll(res.Body)
require.NoError(t, err)
require.JSONEq(t, unauthorizedExpectedData, string(data))
}()
}
})
})
})

t.Run("invalid token", func(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -1182,7 +1327,7 @@ func TestAlgorithmMismatch(t *testing.T) {

authenticators := []authentication.Authenticator{authenticator}

token, err := authServer.TokenForKID(crypto.KID(), nil)
token, err := authServer.TokenForKID(crypto.KID(), nil, false)
require.NoError(t, err)

return token, authenticators
Expand Down Expand Up @@ -1307,7 +1452,7 @@ func TestOidcDiscovery(t *testing.T) {
tokens := make(map[string]string)

for _, c := range crypto {
token, err := authServer.TokenForKID(c.KID(), nil)
token, err := authServer.TokenForKID(c.KID(), nil, false)
require.NoError(t, err)

tokens[c.KID()] = token
Expand Down Expand Up @@ -1421,7 +1566,7 @@ func TestMultipleKeys(t *testing.T) {
tokens := make(map[string]string)

for _, c := range crypto {
token, err := authServer.TokenForKID(c.KID(), nil)
token, err := authServer.TokenForKID(c.KID(), nil, false)
require.NoError(t, err)

tokens[c.KID()] = token
Expand Down Expand Up @@ -1604,7 +1749,7 @@ func TestSupportedAlgorithms(t *testing.T) {

authenticators := []authentication.Authenticator{authenticator}

token, err := authServer.TokenForKID(crypto.KID(), nil)
token, err := authServer.TokenForKID(crypto.KID(), nil, false)
require.NoError(t, err)

return token, authenticators
Expand Down
2 changes: 1 addition & 1 deletion router-tests/header_set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ func TestHeaderSetWithExpression(t *testing.T) {
authenticator, err := authentication.NewHttpHeaderAuthenticator(authOptions)
require.NoError(t, err)

token, err := authServer.TokenForKID(rsa1.KID(), map[string]any{"user_id": "TestId"})
token, err := authServer.TokenForKID(rsa1.KID(), map[string]any{"user_id": "TestId"}, false)
require.NoError(t, err)

testenv.Run(t, &testenv.Config{
Expand Down
11 changes: 9 additions & 2 deletions router-tests/jwks/jwks.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,18 @@ func (s *Server) Token(claims map[string]any) (string, error) {
return "", jwt.ErrInvalidKey
}

func (s *Server) TokenForKID(kid string, claims map[string]any) (string, error) {
func (s *Server) TokenForKID(kid string, claims map[string]any, useInvalidKID bool) (string, error) {
provider, ok := s.providers[kid]
if !ok {
if useInvalidKID {
// If we don't care about the kid, use any available provider
for _, pr := range s.providers {
provider = pr
break
}
} else if !ok {
return "", jwt.ErrInvalidKey
}

token := jwt.NewWithClaims(provider.SigningMethod(), jwt.MapClaims(claims))
token.Header[jwkset.HeaderKID] = kid
return token.SignedString(provider.PrivateKey())
Expand Down
12 changes: 12 additions & 0 deletions router-tests/testenv/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,15 @@ func AwaitChannelWithCloseWithT[A any](t *testing.T, timeout time.Duration, ch <
require.Fail(t, "unable to receive message before timeout", msgAndArgs...)
}
}

func AwaitFunc(t *testing.T, timeout time.Duration, testFunction func()) {
t.Helper()

doneCh := make(chan struct{})
go func() {
defer close(doneCh)
testFunction()
}()

AwaitChannelWithT(t, timeout, doneCh, func(t *testing.T, _ struct{}) {}, "the test function timed out")
}
Comment thread
SkArchon marked this conversation as resolved.
5 changes: 5 additions & 0 deletions router/core/supervisor_instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,11 @@ func setupAuthenticators(ctx context.Context, logger *zap.Logger, cfg *config.Co
KeyId: jwks.KeyId,

Audiences: jwks.Audiences,
RefreshUnknownKID: authentication.RefreshUnknownKIDConfig{
Enabled: jwks.RefreshUnknownKID.Enabled,
Interval: jwks.RefreshUnknownKID.Interval,
Burst: jwks.RefreshUnknownKID.Burst,
},
})
}

Expand Down
2 changes: 1 addition & 1 deletion router/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ require (
go.uber.org/ratelimit v0.3.1
golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8
golang.org/x/text v0.23.0
golang.org/x/time v0.9.0
)

require (
Expand Down Expand Up @@ -166,6 +165,7 @@ require (
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/crypto v0.36.0 // indirect
golang.org/x/net v0.38.0 // indirect
golang.org/x/time v0.9.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f // indirect
gopkg.in/cenkalti/backoff.v1 v1.1.0 // indirect
Expand Down
22 changes: 18 additions & 4 deletions router/pkg/authentication/jwks_token_decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ import (
"net/http"
"time"

"golang.org/x/time/rate"

"github.com/MicahParks/jwkset"
"github.com/MicahParks/keyfunc/v3"
"github.com/golang-jwt/jwt/v5"
"go.uber.org/zap"
"golang.org/x/time/rate"

"github.com/wundergraph/cosmo/router/internal/httpclient"
"go.uber.org/zap"
)

type TokenDecoder interface {
Expand Down Expand Up @@ -49,6 +49,14 @@ type JWKSConfig struct {
KeyId string

Audiences []string

RefreshUnknownKID RefreshUnknownKIDConfig
}

type RefreshUnknownKIDConfig struct {
Enabled bool
Interval time.Duration
Burst int
}

type audKey struct {
Expand Down Expand Up @@ -91,12 +99,18 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS

audiencesMap[key] = getAudienceSet(c.Audiences)

// Configure the rate limiter for refreshing unknown KIDs
var refreshLimiter *rate.Limiter
if c.RefreshUnknownKID.Enabled {
refreshLimiter = rate.NewLimiter(rate.Every(c.RefreshUnknownKID.Interval), c.RefreshUnknownKID.Burst)
}

Comment thread
SkArchon marked this conversation as resolved.
Outdated
jwksetHTTPClientOptions := jwkset.HTTPClientOptions{
HTTPURLs: map[string]jwkset.Storage{
c.URL: store,
},
PrioritizeHTTP: true,
RefreshUnknownKID: rate.NewLimiter(rate.Every(5*time.Minute), 1),
RefreshUnknownKID: refreshLimiter,
}

jwks, err := createKeyFunc(ctx, jwksetHTTPClientOptions)
Expand Down
13 changes: 10 additions & 3 deletions router/pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -466,9 +466,10 @@ type OverridesConfiguration struct {
}

type JWKSConfiguration struct {
URL string `yaml:"url"`
Algorithms []string `yaml:"algorithms"`
RefreshInterval time.Duration `yaml:"refresh_interval" envDefault:"1m"`
URL string `yaml:"url"`
Algorithms []string `yaml:"algorithms"`
RefreshInterval time.Duration `yaml:"refresh_interval" envDefault:"1m"`
RefreshUnknownKID RefreshUnknownKIDConfig `yaml:",inline"`

Comment thread
coderabbitai[bot] marked this conversation as resolved.
// For secret based where we need to create a jwk entry with
// a key id and algorithm
Expand All @@ -480,6 +481,12 @@ type JWKSConfiguration struct {
Audiences []string `yaml:"audiences"`
}

type RefreshUnknownKIDConfig struct {
Enabled bool `yaml:"enabled" envDefault:"false"`
Interval time.Duration `yaml:"interval" envDefault:"1m"`
Burst int `yaml:"burst" envDefault:"2"`
}

type HeaderSource struct {
Type string `yaml:"type"`
Name string `yaml:"name"`
Expand Down
26 changes: 26 additions & 0 deletions router/pkg/config/config.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -1716,6 +1716,32 @@
},
"description": "The interval at which the JWKs are refreshed. The period is specified as a string with a number and a unit, e.g. 10ms, 1s, 1m, 1h. The supported units are 'ms', 's', 'm', 'h'.",
"default": "1m"
},
"refresh_unknown_kid": {
"type": "object",
"description": "Controls rate-limited refresh behavior when a JWT KID is unknown.",
"additionalProperties": false,
"properties": {
"enabled": {
"type": "boolean",
"description": "Enable refresh attempts on unknown KID.",
"default": false
},
"interval": {
"type": "string",
"description": "Token refill interval for the rate limiter.",
"default": "1m",
"duration": {
"minimum": "1s"
}
},
"burst": {
"type": "integer",
"description": "Burst size for the rate limiter.",
"default": 2,
"minimum": 1
}
}
}
},
"oneOf": [
Expand Down
Loading
Loading