diff --git a/router-tests/authentication_test.go b/router-tests/authentication_test.go index f4205b11d5..38d2715a56 100644 --- a/router-tests/authentication_test.go +++ b/router-tests/authentication_test.go @@ -683,6 +683,27 @@ func TestAuthenticationWithCustomHeaders(t *testing.T) { func TestHttpJwksAuthorization(t *testing.T) { t.Parallel() + t.Run("startup should fail when duplicate URLs are specified", func(t *testing.T) { + t.Parallel() + + authServer, err := jwks.NewServer(t) + require.NoError(t, err) + t.Cleanup(authServer.Close) + + _, err = authentication.NewJwksTokenDecoder(NewContextWithCancel(t), zap.NewNop(), []authentication.JWKSConfig{ + { + URL: authServer.JWKSURL(), + RefreshInterval: 2 * time.Second, + }, + { + URL: authServer.JWKSURL(), + RefreshInterval: 2 * time.Second, + }, + }) + + require.ErrorContains(t, err, "duplicate JWK URL found") + }) + t.Run("authentication should fail with no token", func(t *testing.T) { t.Parallel() @@ -765,7 +786,10 @@ func TestHttpJwksAuthorization(t *testing.T) { t.Cleanup(authServer2.Close) require.NoError(t, err) - token, err := authServer2.Token(nil) + // aud claim + token, err := authServer2.Token(map[string]any{ + "aud": "https://example.com", + }) require.NoError(t, err) authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ @@ -807,6 +831,28 @@ func TestHttpJwksAuthorization(t *testing.T) { } func TestNonHttpAuthorization(t *testing.T) { + t.Run("startup should fail when duplicate key ids are manually specified", func(t *testing.T) { + t.Parallel() + + secret := "example secret" + kid := "givenKID" + + _, err := authentication.NewJwksTokenDecoder(NewContextWithCancel(t), zap.NewNop(), []authentication.JWKSConfig{ + { + Secret: secret, + Algorithm: string(jwkset.AlgHS256), + KeyId: kid, + }, + { + Secret: secret, + Algorithm: string(jwkset.AlgHS256), + KeyId: kid, + }, + }) + + require.ErrorContains(t, err, "duplicate JWK keyid specified found") + }) + t.Run("authentication should succeed with a valid HS256 token", func(t *testing.T) { t.Parallel() @@ -820,7 +866,7 @@ func TestNonHttpAuthorization(t *testing.T) { }, }) - token := generateToken(t, kid, secret, jwt.SigningMethodHS256) + token := generateToken(t, kid, secret, jwt.SigningMethodHS256, nil) testenv.Run(t, &testenv.Config{ RouterOptions: []core.Option{ @@ -863,7 +909,7 @@ func TestNonHttpAuthorization(t *testing.T) { }, }) - token := generateToken(t, kid, secret, jwt.SigningMethodHS256) + token := generateToken(t, kid, secret, jwt.SigningMethodHS256, nil) testenv.Run(t, &testenv.Config{ RouterOptions: []core.Option{ @@ -902,7 +948,7 @@ func TestNonHttpAuthorization(t *testing.T) { }, }) - token := generateToken(t, "differentKID", secret, jwt.SigningMethodHS256) + token := generateToken(t, "differentKID", secret, jwt.SigningMethodHS256, nil) testenv.Run(t, &testenv.Config{ RouterOptions: []core.Option{ @@ -2028,6 +2074,436 @@ func TestAuthenticationOverWebsocket(t *testing.T) { }) } +func TestAudienceValidation(t *testing.T) { + t.Parallel() + + t.Run("authentication fails when there is no audience match", func(t *testing.T) { + t.Parallel() + + t.Run("with slice of string audiences in the token", func(t *testing.T) { + t.Parallel() + + t.Run("with http based configuration", func(t *testing.T) { + t.Parallel() + + tokenAudiences := []string{"aud1", "aud2"} + + authServer, err := jwks.NewServer(t) + require.NoError(t, err) + t.Cleanup(authServer.Close) + + token, err := authServer.Token(map[string]any{"aud": tokenAudiences}) + require.NoError(t, err) + + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + URL: authServer.JWKSURL(), + RefreshInterval: time.Second * 5, + Audiences: []string{"aud3", "aud5"}, + }, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Operations with a token should succeed + header := http.Header{ + "Authorization": []string{"Bearer " + token}, + } + res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusUnauthorized, res.StatusCode) + require.Equal(t, "", res.Header.Get(xAuthenticatedByHeader)) + data, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.JSONEq(t, unauthorizedExpectedData, string(data)) + }) + }) + + t.Run("with secret based configuration", func(t *testing.T) { + t.Parallel() + + tokenAudiences := []string{"aud1", "aud2"} + + secret := "example secret" + kid := "givenKID" + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + Secret: secret, + Algorithm: string(jwkset.AlgHS256), + KeyId: kid, + Audiences: []string{"aud3", "aud5"}, + }, + }) + + token := generateToken(t, kid, secret, jwt.SigningMethodHS256, jwt.MapClaims{ + "aud": tokenAudiences, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Operations with a token should succeed + header := http.Header{ + "Authorization": []string{"Bearer " + token}, + } + res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusUnauthorized, res.StatusCode) + require.Equal(t, "", res.Header.Get(xAuthenticatedByHeader)) + data, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.JSONEq(t, unauthorizedExpectedData, string(data)) + }) + }) + }) + + t.Run("with single string audience in the token", func(t *testing.T) { + t.Parallel() + + t.Run("with http based configuration", func(t *testing.T) { + t.Parallel() + + tokenAudiences := "aud1" + + authServer, err := jwks.NewServer(t) + require.NoError(t, err) + t.Cleanup(authServer.Close) + + token, err := authServer.Token(map[string]any{"aud": tokenAudiences}) + require.NoError(t, err) + + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + URL: authServer.JWKSURL(), + RefreshInterval: time.Second * 5, + Audiences: []string{"aud3", "aud5"}, + }, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Operations with a token should succeed + header := http.Header{ + "Authorization": []string{"Bearer " + token}, + } + res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusUnauthorized, res.StatusCode) + require.Equal(t, "", res.Header.Get(xAuthenticatedByHeader)) + data, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.JSONEq(t, unauthorizedExpectedData, string(data)) + }) + }) + + t.Run("with secret based configuration", func(t *testing.T) { + t.Parallel() + + tokenAudience := "aud1" + + secret := "example secret" + kid := "givenKID" + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + Secret: secret, + Algorithm: string(jwkset.AlgHS256), + KeyId: kid, + Audiences: []string{"aud3", "aud5"}, + }, + }) + + token := generateToken(t, kid, secret, jwt.SigningMethodHS256, jwt.MapClaims{ + "aud": tokenAudience, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Operations with a token should succeed + header := http.Header{ + "Authorization": []string{"Bearer " + token}, + } + res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusUnauthorized, res.StatusCode) + require.Equal(t, "", res.Header.Get(xAuthenticatedByHeader)) + data, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.JSONEq(t, unauthorizedExpectedData, string(data)) + }) + }) + }) + }) + + t.Run("authentication succeeds when there is an audience match", func(t *testing.T) { + t.Parallel() + + t.Run("with slice of string audiences in the token", func(t *testing.T) { + t.Parallel() + + t.Run("with http based configuration", func(t *testing.T) { + t.Parallel() + + matchingAudience := "matchingAudience" + tokenAudiences := []string{matchingAudience, "aud5"} + + authServer, err := jwks.NewServer(t) + require.NoError(t, err) + t.Cleanup(authServer.Close) + + token, err := authServer.Token(map[string]any{"aud": tokenAudiences}) + require.NoError(t, err) + + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + URL: authServer.JWKSURL(), + RefreshInterval: time.Second * 5, + Audiences: []string{matchingAudience, "aud5"}, + }, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Operations with a token should succeed + header := http.Header{ + "Authorization": []string{"Bearer " + token}, + } + res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + require.Equal(t, JwksName, res.Header.Get(xAuthenticatedByHeader)) + data, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, employeesExpectedData, string(data)) + }) + }) + + t.Run("with secret based configuration", func(t *testing.T) { + t.Parallel() + + matchingAud := "matchingAud" + tokenAudiences := []string{matchingAud, "aud2"} + + secret := "example secret" + kid := "givenKID" + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + Secret: secret, + Algorithm: string(jwkset.AlgHS256), + KeyId: kid, + Audiences: []string{matchingAud, "aud5"}, + }, + }) + + token := generateToken(t, kid, secret, jwt.SigningMethodHS256, jwt.MapClaims{ + "aud": tokenAudiences, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Operations with a token should succeed + header := http.Header{ + "Authorization": []string{"Bearer " + token}, + } + res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + require.Equal(t, JwksName, res.Header.Get(xAuthenticatedByHeader)) + data, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, employeesExpectedData, string(data)) + }) + }) + }) + + t.Run("with single string audience in the token", func(t *testing.T) { + t.Parallel() + + t.Run("with http based configuration", func(t *testing.T) { + t.Parallel() + + matchingAudience := "matchingAudience" + + authServer, err := jwks.NewServer(t) + require.NoError(t, err) + t.Cleanup(authServer.Close) + + token, err := authServer.Token(map[string]any{"aud": matchingAudience}) + require.NoError(t, err) + + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + URL: authServer.JWKSURL(), + RefreshInterval: time.Second * 5, + Audiences: []string{matchingAudience, "aud5"}, + }, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Operations with a token should succeed + header := http.Header{ + "Authorization": []string{"Bearer " + token}, + } + res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + require.Equal(t, JwksName, res.Header.Get(xAuthenticatedByHeader)) + data, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, employeesExpectedData, string(data)) + }) + }) + + t.Run("with secret based configuration", func(t *testing.T) { + t.Parallel() + + matchingAud := "matchingAudience" + + secret := "example secret" + kid := "givenKID" + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + Secret: secret, + Algorithm: string(jwkset.AlgHS256), + KeyId: kid, + Audiences: []string{matchingAud, "aud5"}, + }, + }) + + token := generateToken(t, kid, secret, jwt.SigningMethodHS256, jwt.MapClaims{ + "aud": matchingAud, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Operations with a token should succeed + header := http.Header{ + "Authorization": []string{"Bearer " + token}, + } + res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + require.Equal(t, JwksName, res.Header.Get(xAuthenticatedByHeader)) + data, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, employeesExpectedData, string(data)) + }) + }) + }) + }) + + t.Run("authentication fails when audience is invalid format", func(t *testing.T) { + t.Parallel() + + tokenAudiences := []bool{true, true} + + authServer, err := jwks.NewServer(t) + require.NoError(t, err) + t.Cleanup(authServer.Close) + + token, err := authServer.Token(map[string]any{"aud": tokenAudiences}) + require.NoError(t, err) + + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + URL: authServer.JWKSURL(), + RefreshInterval: time.Second * 5, + Audiences: []string{"aud3", "aud5"}, + }, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Operations with a token should succeed + header := http.Header{ + "Authorization": []string{"Bearer " + token}, + } + res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusUnauthorized, res.StatusCode) + require.Equal(t, "", res.Header.Get(xAuthenticatedByHeader)) + data, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.JSONEq(t, unauthorizedExpectedData, string(data)) + }) + + }) + + t.Run("audience validation is ignored when expected aud is not provided", func(t *testing.T) { + t.Parallel() + + tokenAudiences := []bool{true, true} + + authServer, err := jwks.NewServer(t) + require.NoError(t, err) + t.Cleanup(authServer.Close) + + token, err := authServer.Token(map[string]any{"aud": tokenAudiences}) + require.NoError(t, err) + + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + URL: authServer.JWKSURL(), + RefreshInterval: time.Second * 5, + }, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Operations with a token should succeed + header := http.Header{ + "Authorization": []string{"Bearer " + token}, + } + res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + require.Equal(t, JwksName, res.Header.Get(xAuthenticatedByHeader)) + data, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, employeesExpectedData, string(data)) + }) + }) +} + func toJWKSConfig(url string, refresh time.Duration, allowedAlgorithms ...string) authentication.JWKSConfig { return authentication.JWKSConfig{ URL: url, @@ -2036,8 +2512,11 @@ func toJWKSConfig(url string, refresh time.Duration, allowedAlgorithms ...string } } -func generateToken(t *testing.T, kid string, secret string, signingMethod *jwt.SigningMethodHMAC) string { - token := jwt.New(signingMethod) +func generateToken(t *testing.T, kid string, secret string, signingMethod *jwt.SigningMethodHMAC, claims jwt.MapClaims) string { + if claims == nil { + claims = jwt.MapClaims{} + } + token := jwt.NewWithClaims(signingMethod, claims) token.Header[jwkset.HeaderKID] = kid jwtValue, err := token.SignedString([]byte(secret)) require.NoError(t, err) diff --git a/router/core/supervisor_instance.go b/router/core/supervisor_instance.go index f85699496e..b989d42599 100644 --- a/router/core/supervisor_instance.go +++ b/router/core/supervisor_instance.go @@ -259,6 +259,8 @@ func setupAuthenticators(ctx context.Context, logger *zap.Logger, cfg *config.Co Secret: jwks.Secret, Algorithm: jwks.Algorithm, KeyId: jwks.KeyId, + + Audiences: jwks.Audiences, }) } diff --git a/router/pkg/authentication/authentication.go b/router/pkg/authentication/authentication.go index 70963858d7..c08906ab09 100644 --- a/router/pkg/authentication/authentication.go +++ b/router/pkg/authentication/authentication.go @@ -77,6 +77,8 @@ func (a *authentication) Scopes() []string { return strings.Split(scopes, " ") } +var errUnacceptableAud = errors.New("audience match not found") + // Authenticate tries to authenticate the given Provider using the given authenticators. If any of // the authenticators succeeds, the Authentication result is returned with no error. If the Provider // has no authentication information, the Authentication result is nil with no error. If the authentication diff --git a/router/pkg/authentication/jwks_token_decoder.go b/router/pkg/authentication/jwks_token_decoder.go index 652ad385e8..53b252632c 100644 --- a/router/pkg/authentication/jwks_token_decoder.go +++ b/router/pkg/authentication/jwks_token_decoder.go @@ -2,6 +2,7 @@ package authentication import ( "context" + "errors" "fmt" "net/http" "time" @@ -20,12 +21,12 @@ type TokenDecoder interface { } type jwksTokenDecoder struct { - jwks keyfunc.Keyfunc + jwks jwt.Keyfunc } // Decode implements TokenDecoder. func (j *jwksTokenDecoder) Decode(tokenString string) (Claims, error) { - token, err := jwt.Parse(tokenString, j.jwks.Keyfunc) + token, err := jwt.Parse(tokenString, j.jwks) if err != nil { return nil, fmt.Errorf("could not validate token: %w", err) } @@ -46,16 +47,28 @@ type JWKSConfig struct { Secret string Algorithm string KeyId string + + Audiences []string } -func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKSConfig) (TokenDecoder, error) { +type audKey struct { + kid string + url string +} - remoteJWKSets := make(map[string]jwkset.Storage) +type audienceSet map[string]struct{} - given := jwkset.NewMemoryStorage() +func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKSConfig) (TokenDecoder, error) { + audiencesMap := make(map[audKey]audienceSet, len(configs)) + keyFuncMap := make(map[audKey]keyfunc.Keyfunc, len(configs)) for _, c := range configs { if c.URL != "" { + key := audKey{url: c.URL} + if _, ok := audiencesMap[key]; ok { + return nil, fmt.Errorf("duplicate JWK URL found: %s", c.URL) + } + l := logger.With(zap.String("url", c.URL)) jwksetHTTPStorageOptions := jwkset.HTTPClientStorageOptions{ @@ -76,8 +89,30 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS return nil, fmt.Errorf("failed to create HTTP client storage for JWK provider: %w", err) } - remoteJWKSets[c.URL] = store + audiencesMap[key] = getAudienceSet(c.Audiences) + + jwksetHTTPClientOptions := jwkset.HTTPClientOptions{ + HTTPURLs: map[string]jwkset.Storage{ + c.URL: store, + }, + PrioritizeHTTP: true, + RefreshUnknownKID: rate.NewLimiter(rate.Every(5*time.Minute), 1), + } + + jwks, err := createKeyFunc(ctx, jwksetHTTPClientOptions) + if err != nil { + return nil, err + } + keyFuncMap[key] = jwks + } else if c.Secret != "" { + key := audKey{kid: c.KeyId} + if _, ok := audiencesMap[key]; ok { + return nil, fmt.Errorf("duplicate JWK keyid specified found: %s", c.KeyId) + } + + given := jwkset.NewMemoryStorage() + marshalOptions := jwkset.JWKMarshalOptions{ Private: true, } @@ -103,21 +138,66 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS return nil, fmt.Errorf("failed to create JWK from secret: %w", err) } + audiencesMap[key] = getAudienceSet(c.Audiences) + err = given.KeyWrite(ctx, jwk) if err != nil { return nil, fmt.Errorf("failed to write JWK to storage: %w", err) } + + jwksetHTTPClientOptions := jwkset.HTTPClientOptions{ + Given: given, + PrioritizeHTTP: false, + } + + jwks, err := createKeyFunc(ctx, jwksetHTTPClientOptions) + if err != nil { + return nil, err + } + keyFuncMap[key] = jwks } } - jwksetHTTPClientOptions := jwkset.HTTPClientOptions{ - Given: given, - HTTPURLs: remoteJWKSets, - PrioritizeHTTP: false, - RefreshUnknownKID: rate.NewLimiter(rate.Every(5*time.Minute), 1), + keyFuncWrapper := jwt.Keyfunc(func(token *jwt.Token) (any, error) { + var errJoin error + for key, keyFunc := range keyFuncMap { + pub, err := keyFunc.Keyfunc(token) + if err != nil { + errJoin = errors.Join(errJoin, err) + continue + } + + expectedAudiences := audiencesMap[key] + if len(expectedAudiences) > 0 { + tokenAudiences, err := token.Claims.GetAudience() + if err != nil { + return nil, fmt.Errorf("could not get audiences from token claims: %w", err) + } + if !hasAudience(tokenAudiences, expectedAudiences) { + return nil, errUnacceptableAud + } + } + return pub, nil + } + + return nil, fmt.Errorf("no key found for token: %w", errors.Join(errJoin, jwt.ErrTokenUnverifiable)) + }) + + return &jwksTokenDecoder{ + jwks: keyFuncWrapper, + }, nil +} + +func getAudienceSet(audiences []string) audienceSet { + audSet := make(audienceSet, len(audiences)) + for _, aud := range audiences { + audSet[aud] = struct{}{} } + return audSet +} - combined, err := jwkset.NewHTTPClient(jwksetHTTPClientOptions) +func createKeyFunc(ctx context.Context, options jwkset.HTTPClientOptions) (keyfunc.Keyfunc, error) { + combined, err := jwkset.NewHTTPClient(options) if err != nil { return nil, fmt.Errorf("failed to create HTTP client storage for JWK provider: %w", err) } @@ -132,8 +212,15 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS if err != nil { return nil, fmt.Errorf("error initializing JWK: %w", err) } + return jwks, nil +} - return &jwksTokenDecoder{ - jwks: jwks, - }, nil +// hasAudience is a common intersection function to check on the token's audiences +func hasAudience(tokenAudiences []string, expectedAudiences audienceSet) bool { + for _, item := range tokenAudiences { + if _, found := expectedAudiences[item]; found { + return true + } + } + return false } diff --git a/router/pkg/config/config.go b/router/pkg/config/config.go index e066e6b91c..c51e2810f5 100644 --- a/router/pkg/config/config.go +++ b/router/pkg/config/config.go @@ -467,8 +467,11 @@ type JWKSConfiguration struct { // For secret based where we need to create a jwk entry with // a key id and algorithm Secret string `yaml:"secret"` - Algorithm string `yaml:"algorithm"` - KeyId string `yaml:"key_id"` + Algorithm string `yaml:"symmetric_algorithm"` + KeyId string `yaml:"header_key_id"` + + // Common + Audiences []string `yaml:"audiences"` } type HeaderSource struct { diff --git a/router/pkg/config/config.schema.json b/router/pkg/config/config.schema.json index 152b80041e..54b2635627 100644 --- a/router/pkg/config/config.schema.json +++ b/router/pkg/config/config.schema.json @@ -1649,22 +1649,29 @@ "description": "The URL of the JWKs. The JWKs are used to verify the JWT (JSON Web Token). The URL is specified as a string with the format 'scheme://host:port'.", "format": "http-url" }, + "audiences": { + "type": "array", + "description": "The audiences of the JWKs. The audiences are used to verify the JWT (JSON Web Token). The audiences are specified as a list of strings.", + "items": { + "type": "string" + } + }, "secret": { "type": "string", "description": "The secret of the JWKs" }, - "algorithm": { + "symmetric_algorithm": { "type": "string", - "description": "The algorithm used", + "description": "The symmetric algorithm used", "enum": [ "HS256", "HS384", "HS512" ] }, - "key_id": { + "header_key_id": { "type": "string", - "description": "The secret of the JWKs" + "description": "The KID header of the JWK token created using the secret" }, "algorithms": { "type": "array", @@ -1672,6 +1679,9 @@ "items": { "type": "string", "enum": [ + "HS256", + "HS384", + "HS512", "RS256", "RS384", "RS512", @@ -1697,10 +1707,10 @@ "oneOf": [ { "required": ["url"], - "not": { "anyOf": [{ "required": ["secret"] }, { "required": ["algorithm"] }, { "required": ["key_id"] }] } + "not": { "anyOf": [{ "required": ["secret"] }, { "required": ["symmetric_algorithm"] }, { "required": ["header_key_id"] }] } }, { - "required": ["secret", "algorithm", "key_id"], + "required": ["secret", "symmetric_algorithm", "header_key_id"], "not": { "anyOf": [{ "required": ["url"] }, { "required": ["algorithms"] }, { "required": ["refresh_interval"] }] } } ] diff --git a/router/pkg/config/config_test.go b/router/pkg/config/config_test.go index c82ff0c4e0..583db5fcf2 100644 --- a/router/pkg/config/config_test.go +++ b/router/pkg/config/config_test.go @@ -1227,9 +1227,9 @@ version: "1" authentication: jwt: jwks: - - key_id: "givenKID" + - header_key_id: "givenKID" secret: "example secret" - algorithm: HS512 + symmetric_algorithm: HS512 `) _, err := LoadConfig([]string{f}) @@ -1245,11 +1245,11 @@ version: "1" authentication: jwt: jwks: - - key_id: "givenKID" + - header_key_id: "givenKID" url: "http://url/valid.json" algorithms: [] secret: "example secret" - algorithm: HS512 + symmetric_algorithm: HS512 `) _, err := LoadConfig([]string{f}) require.ErrorContains(t, err, "at '/authentication/jwt/jwks/") @@ -1270,7 +1270,7 @@ authentication: `) _, err := LoadConfig([]string{f}) require.ErrorContains(t, err, "at '/authentication/jwt/jwks/") - require.ErrorContains(t, err, "missing properties 'algorithm', 'key_id'") + require.ErrorContains(t, err, "missing properties 'symmetric_algorithm', 'header_key_id'") }) diff --git a/router/pkg/config/testdata/config_full.json b/router/pkg/config/testdata/config_full.json index d5f57d5133..4426798c21 100644 --- a/router/pkg/config/testdata/config_full.json +++ b/router/pkg/config/testdata/config_full.json @@ -471,7 +471,8 @@ "RefreshInterval": 60000000000, "Secret": "", "Algorithm": "", - "KeyId": "" + "KeyId": "", + "Audiences": null }, { "URL": "https://example.com/.well-known/jwks2.json", @@ -482,7 +483,8 @@ "RefreshInterval": 120000000000, "Secret": "", "Algorithm": "", - "KeyId": "" + "KeyId": "", + "Audiences": null }, { "URL": "https://example.com/.well-known/jwks3.json", @@ -490,7 +492,8 @@ "RefreshInterval": 0, "Secret": "", "Algorithm": "", - "KeyId": "" + "KeyId": "", + "Audiences": null } ], "HeaderName": "Authorization",