Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 46 additions & 4 deletions router-tests/authentication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,48 @@ func TestHttpJwksAuthorization(t *testing.T) {
})
})

t.Run("authentication should not block with an invalid token on multiple calls", func(t *testing.T) {
t.Parallel()

authenticators, authServer := ConfigureAuth(t)
testenv.Run(t, &testenv.Config{
RouterOptions: []core.Option{
core.WithAccessController(core.NewAccessController(authenticators, true)),
},
}, func(t *testing.T, xEnv *testenv.Environment) {
// This token has a static keyid
token, err := authServer.TokenForKID("wg_static_kid", nil, true)
require.NoError(t, err)

maxDuration := 5 * time.Second

doneCh := make(chan struct{})
go func() {
defer close(doneCh)
for range 5 {
Comment thread
SkArchon marked this conversation as resolved.
Outdated
func() {
// Operations with an invalid token should fail
header := http.Header{
"Authorization": []string{"Bearer " + token},
}
res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery))
require.NoError(t, err)
defer func() {
_ = res.Body.Close()
}()
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
require.Equal(t, "", res.Header.Get(xAuthenticatedByHeader))
data, err := io.ReadAll(res.Body)
require.NoError(t, err)
require.JSONEq(t, unauthorizedExpectedData, string(data))
}()
}
}()

testenv.AwaitChannelWithT(t, maxDuration, doneCh, func(t *testing.T, _ struct{}) {}, "test timed out")
})
})

}

func TestNonHttpAuthorization(t *testing.T) {
Expand Down Expand Up @@ -1182,7 +1224,7 @@ func TestAlgorithmMismatch(t *testing.T) {

authenticators := []authentication.Authenticator{authenticator}

token, err := authServer.TokenForKID(crypto.KID(), nil)
token, err := authServer.TokenForKID(crypto.KID(), nil, false)
require.NoError(t, err)

return token, authenticators
Expand Down Expand Up @@ -1307,7 +1349,7 @@ func TestOidcDiscovery(t *testing.T) {
tokens := make(map[string]string)

for _, c := range crypto {
token, err := authServer.TokenForKID(c.KID(), nil)
token, err := authServer.TokenForKID(c.KID(), nil, false)
require.NoError(t, err)

tokens[c.KID()] = token
Expand Down Expand Up @@ -1421,7 +1463,7 @@ func TestMultipleKeys(t *testing.T) {
tokens := make(map[string]string)

for _, c := range crypto {
token, err := authServer.TokenForKID(c.KID(), nil)
token, err := authServer.TokenForKID(c.KID(), nil, false)
require.NoError(t, err)

tokens[c.KID()] = token
Expand Down Expand Up @@ -1604,7 +1646,7 @@ func TestSupportedAlgorithms(t *testing.T) {

authenticators := []authentication.Authenticator{authenticator}

token, err := authServer.TokenForKID(crypto.KID(), nil)
token, err := authServer.TokenForKID(crypto.KID(), nil, false)
require.NoError(t, err)

return token, authenticators
Expand Down
2 changes: 1 addition & 1 deletion router-tests/header_set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ func TestHeaderSetWithExpression(t *testing.T) {
authenticator, err := authentication.NewHttpHeaderAuthenticator(authOptions)
require.NoError(t, err)

token, err := authServer.TokenForKID(rsa1.KID(), map[string]any{"user_id": "TestId"})
token, err := authServer.TokenForKID(rsa1.KID(), map[string]any{"user_id": "TestId"}, false)
require.NoError(t, err)

testenv.Run(t, &testenv.Config{
Expand Down
12 changes: 10 additions & 2 deletions router-tests/jwks/jwks.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,19 @@ func (s *Server) Token(claims map[string]any) (string, error) {
return "", jwt.ErrInvalidKey
}

func (s *Server) TokenForKID(kid string, claims map[string]any) (string, error) {
func (s *Server) TokenForKID(kid string, claims map[string]any, useInvalidKID bool) (string, error) {
provider, ok := s.providers[kid]
if !ok {
if !useInvalidKID && !ok {
Comment thread
SkArchon marked this conversation as resolved.
Outdated
return "", jwt.ErrInvalidKey
} else if useInvalidKID {
// If we don't care about the kid we don't care about the provider
// we just get the first random provider provided
for _, pr := range s.providers {
provider = pr
break
}
Comment thread
endigma marked this conversation as resolved.
Outdated
}

token := jwt.NewWithClaims(provider.SigningMethod(), jwt.MapClaims(claims))
token.Header[jwkset.HeaderKID] = kid
return token.SignedString(provider.PrivateKey())
Expand Down
2 changes: 1 addition & 1 deletion router/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ require (
go.uber.org/ratelimit v0.3.1
golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8
golang.org/x/text v0.23.0
golang.org/x/time v0.9.0
)

require (
Expand Down Expand Up @@ -166,6 +165,7 @@ require (
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/crypto v0.36.0 // indirect
golang.org/x/net v0.38.0 // indirect
golang.org/x/time v0.9.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f // indirect
gopkg.in/cenkalti/backoff.v1 v1.1.0 // indirect
Expand Down
7 changes: 2 additions & 5 deletions router/pkg/authentication/jwks_token_decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,8 @@ import (
"github.com/MicahParks/jwkset"
"github.com/MicahParks/keyfunc/v3"
"github.com/golang-jwt/jwt/v5"
"go.uber.org/zap"
"golang.org/x/time/rate"

"github.com/wundergraph/cosmo/router/internal/httpclient"
"go.uber.org/zap"
)

type TokenDecoder interface {
Expand Down Expand Up @@ -95,8 +93,7 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS
HTTPURLs: map[string]jwkset.Storage{
c.URL: store,
},
PrioritizeHTTP: true,
RefreshUnknownKID: rate.NewLimiter(rate.Every(5*time.Minute), 1),
PrioritizeHTTP: true,
}

jwks, err := createKeyFunc(ctx, jwksetHTTPClientOptions)
Expand Down
Loading