Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
491 changes: 485 additions & 6 deletions router-tests/authentication_test.go

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions router/core/supervisor_instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,8 @@ func setupAuthenticators(ctx context.Context, logger *zap.Logger, cfg *config.Co
Secret: jwks.Secret,
Algorithm: jwks.Algorithm,
KeyId: jwks.KeyId,

Audiences: jwks.Audiences,
})
}

Expand Down
2 changes: 2 additions & 0 deletions router/pkg/authentication/authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ func (a *authentication) Scopes() []string {
return strings.Split(scopes, " ")
}

var errUnacceptableAud = errors.New("audience match not found")

// Authenticate tries to authenticate the given Provider using the given authenticators. If any of
// the authenticators succeeds, the Authentication result is returned with no error. If the Provider
// has no authentication information, the Authentication result is nil with no error. If the authentication
Expand Down
117 changes: 102 additions & 15 deletions router/pkg/authentication/jwks_token_decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package authentication

import (
"context"
"errors"
"fmt"
"net/http"
"time"
Expand All @@ -20,12 +21,12 @@ type TokenDecoder interface {
}

type jwksTokenDecoder struct {
jwks keyfunc.Keyfunc
jwks jwt.Keyfunc
}

// Decode implements TokenDecoder.
func (j *jwksTokenDecoder) Decode(tokenString string) (Claims, error) {
token, err := jwt.Parse(tokenString, j.jwks.Keyfunc)
token, err := jwt.Parse(tokenString, j.jwks)
if err != nil {
return nil, fmt.Errorf("could not validate token: %w", err)
}
Expand All @@ -46,16 +47,28 @@ type JWKSConfig struct {
Secret string
Algorithm string
KeyId string

Audiences []string
}

func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKSConfig) (TokenDecoder, error) {
type audKey struct {
kid string
url string
}

remoteJWKSets := make(map[string]jwkset.Storage)
type audienceSet map[string]struct{}

given := jwkset.NewMemoryStorage()
func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKSConfig) (TokenDecoder, error) {
audiencesMap := make(map[audKey]audienceSet, len(configs))
keyFuncMap := make(map[audKey]keyfunc.Keyfunc, len(configs))

for _, c := range configs {
if c.URL != "" {
key := audKey{url: c.URL}
if _, ok := audiencesMap[key]; ok {
return nil, fmt.Errorf("duplicate JWK URL found: %s", c.URL)
}

l := logger.With(zap.String("url", c.URL))

jwksetHTTPStorageOptions := jwkset.HTTPClientStorageOptions{
Expand All @@ -76,8 +89,30 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS
return nil, fmt.Errorf("failed to create HTTP client storage for JWK provider: %w", err)
}

remoteJWKSets[c.URL] = store
audiencesMap[key] = getAudienceSet(c.Audiences)

jwksetHTTPClientOptions := jwkset.HTTPClientOptions{
HTTPURLs: map[string]jwkset.Storage{
c.URL: store,
},
PrioritizeHTTP: true,
RefreshUnknownKID: rate.NewLimiter(rate.Every(5*time.Minute), 1),
}

jwks, err := createKeyFunc(ctx, jwksetHTTPClientOptions)
if err != nil {
return nil, err
}
keyFuncMap[key] = jwks

} else if c.Secret != "" {
key := audKey{kid: c.KeyId}
if _, ok := audiencesMap[key]; ok {
return nil, fmt.Errorf("duplicate JWK keyid specified found: %s", c.KeyId)
}

given := jwkset.NewMemoryStorage()

marshalOptions := jwkset.JWKMarshalOptions{
Private: true,
}
Expand All @@ -103,21 +138,66 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS
return nil, fmt.Errorf("failed to create JWK from secret: %w", err)
}

audiencesMap[key] = getAudienceSet(c.Audiences)

err = given.KeyWrite(ctx, jwk)
if err != nil {
return nil, fmt.Errorf("failed to write JWK to storage: %w", err)
}

jwksetHTTPClientOptions := jwkset.HTTPClientOptions{
Given: given,
PrioritizeHTTP: false,
}

jwks, err := createKeyFunc(ctx, jwksetHTTPClientOptions)
if err != nil {
return nil, err
}
keyFuncMap[key] = jwks
}
}

jwksetHTTPClientOptions := jwkset.HTTPClientOptions{
Given: given,
HTTPURLs: remoteJWKSets,
PrioritizeHTTP: false,
RefreshUnknownKID: rate.NewLimiter(rate.Every(5*time.Minute), 1),
keyFuncWrapper := jwt.Keyfunc(func(token *jwt.Token) (any, error) {
var errJoin error
for key, keyFunc := range keyFuncMap {
pub, err := keyFunc.Keyfunc(token)
if err != nil {
errJoin = errors.Join(errJoin, err)
continue
}

expectedAudiences := audiencesMap[key]
if len(expectedAudiences) > 0 {
tokenAudiences, err := token.Claims.GetAudience()
if err != nil {
return nil, fmt.Errorf("could not get audiences from token claims: %w", err)
}
if !hasAudience(tokenAudiences, expectedAudiences) {
return nil, errUnacceptableAud
}
}
return pub, nil
}

return nil, fmt.Errorf("no key found for token: %w", errors.Join(errJoin, jwt.ErrTokenUnverifiable))
})

return &jwksTokenDecoder{
jwks: keyFuncWrapper,
}, nil
}

func getAudienceSet(audiences []string) audienceSet {
audSet := make(audienceSet, len(audiences))
for _, aud := range audiences {
audSet[aud] = struct{}{}
}
return audSet
}

combined, err := jwkset.NewHTTPClient(jwksetHTTPClientOptions)
func createKeyFunc(ctx context.Context, options jwkset.HTTPClientOptions) (keyfunc.Keyfunc, error) {
combined, err := jwkset.NewHTTPClient(options)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client storage for JWK provider: %w", err)
}
Expand All @@ -132,8 +212,15 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS
if err != nil {
return nil, fmt.Errorf("error initializing JWK: %w", err)
}
return jwks, nil
}

return &jwksTokenDecoder{
jwks: jwks,
}, nil
// hasAudience is a common intersection function to check on the token's audiences
func hasAudience(tokenAudiences []string, expectedAudiences audienceSet) bool {
for _, item := range tokenAudiences {
if _, found := expectedAudiences[item]; found {
return true
}
}
return false
}
7 changes: 5 additions & 2 deletions router/pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -467,8 +467,11 @@ type JWKSConfiguration struct {
// For secret based where we need to create a jwk entry with
// a key id and algorithm
Secret string `yaml:"secret"`
Algorithm string `yaml:"algorithm"`
KeyId string `yaml:"key_id"`
Algorithm string `yaml:"symmetric_algorithm"`
KeyId string `yaml:"header_key_id"`

// Common
Audiences []string `yaml:"audiences"`
}

type HeaderSource struct {
Expand Down
22 changes: 16 additions & 6 deletions router/pkg/config/config.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -1649,29 +1649,39 @@
"description": "The URL of the JWKs. The JWKs are used to verify the JWT (JSON Web Token). The URL is specified as a string with the format 'scheme://host:port'.",
"format": "http-url"
},
"audiences": {
"type": "array",
"description": "The audiences of the JWKs. The audiences are used to verify the JWT (JSON Web Token). The audiences are specified as a list of strings.",
"items": {
"type": "string"
}
},
"secret": {
"type": "string",
"description": "The secret of the JWKs"
},
"algorithm": {
"symmetric_algorithm": {
"type": "string",
"description": "The algorithm used",
"description": "The symmetric algorithm used",
"enum": [
"HS256",
"HS384",
"HS512"
]
},
"key_id": {
"header_key_id": {
"type": "string",
"description": "The secret of the JWKs"
"description": "The KID header of the JWK token created using the secret"
},
"algorithms": {
"type": "array",
"description": "The allowed algorithms for the keys that are retrieved from the JWKs. An empty list means that all algorithms are allowed.",
"items": {
"type": "string",
"enum": [
"HS256",
"HS384",
"HS512",
"RS256",
"RS384",
"RS512",
Expand All @@ -1697,10 +1707,10 @@
"oneOf": [
{
"required": ["url"],
"not": { "anyOf": [{ "required": ["secret"] }, { "required": ["algorithm"] }, { "required": ["key_id"] }] }
"not": { "anyOf": [{ "required": ["secret"] }, { "required": ["symmetric_algorithm"] }, { "required": ["header_key_id"] }] }
},
{
"required": ["secret", "algorithm", "key_id"],
"required": ["secret", "symmetric_algorithm", "header_key_id"],
"not": { "anyOf": [{ "required": ["url"] }, { "required": ["algorithms"] }, { "required": ["refresh_interval"] }] }
}
]
Expand Down
10 changes: 5 additions & 5 deletions router/pkg/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1227,9 +1227,9 @@ version: "1"
authentication:
jwt:
jwks:
- key_id: "givenKID"
- header_key_id: "givenKID"
secret: "example secret"
algorithm: HS512
symmetric_algorithm: HS512

`)
_, err := LoadConfig([]string{f})
Expand All @@ -1245,11 +1245,11 @@ version: "1"
authentication:
jwt:
jwks:
- key_id: "givenKID"
- header_key_id: "givenKID"
url: "http://url/valid.json"
algorithms: []
secret: "example secret"
algorithm: HS512
symmetric_algorithm: HS512
`)
_, err := LoadConfig([]string{f})
require.ErrorContains(t, err, "at '/authentication/jwt/jwks/")
Expand All @@ -1270,7 +1270,7 @@ authentication:
`)
_, err := LoadConfig([]string{f})
require.ErrorContains(t, err, "at '/authentication/jwt/jwks/")
require.ErrorContains(t, err, "missing properties 'algorithm', 'key_id'")
require.ErrorContains(t, err, "missing properties 'symmetric_algorithm', 'header_key_id'")

})

Expand Down
9 changes: 6 additions & 3 deletions router/pkg/config/testdata/config_full.json
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,8 @@
"RefreshInterval": 60000000000,
"Secret": "",
"Algorithm": "",
"KeyId": ""
"KeyId": "",
"Audiences": null
},
{
"URL": "https://example.com/.well-known/jwks2.json",
Expand All @@ -482,15 +483,17 @@
"RefreshInterval": 120000000000,
"Secret": "",
"Algorithm": "",
"KeyId": ""
"KeyId": "",
"Audiences": null
},
{
"URL": "https://example.com/.well-known/jwks3.json",
"Algorithms": null,
"RefreshInterval": 0,
"Secret": "",
"Algorithm": "",
"KeyId": ""
"KeyId": "",
"Audiences": null
}
],
"HeaderName": "Authorization",
Expand Down
Loading