diff --git a/router-tests/authentication_test.go b/router-tests/authentication_test.go index 326b84749a..68a44e55c0 100644 --- a/router-tests/authentication_test.go +++ b/router-tests/authentication_test.go @@ -4,6 +4,8 @@ import ( "bytes" "crypto/rsa" "crypto/x509" + "encoding/base64" + "encoding/json" "encoding/pem" "io" "net/http" @@ -2228,7 +2230,6 @@ func TestSupportedAlgorithms(t *testing.T) { t.Parallel() body := testRequest(t, xEnv, authHeader(token), true) require.Equal(t, employeesExpectedData, string(body)) - }) t.Run("Should fail when providing no Token", func(t *testing.T) { @@ -2790,7 +2791,105 @@ func TestAudienceValidation(t *testing.T) { require.NoError(t, err) require.JSONEq(t, unauthorizedExpectedData, string(data)) }) + }) + + t.Run("audience validation succeeds even when one audience match fails", func(t *testing.T) { + t.Parallel() + + t.Run("with http based configuration", func(t *testing.T) { + t.Parallel() + + tokenAudiences := []string{"aud1"} + + authServer1, err := jwks.NewServer(t) + require.NoError(t, err) + t.Cleanup(authServer1.Close) + + authServer2, err := jwks.NewServer(t) + require.NoError(t, err) + t.Cleanup(authServer2.Close) + + token, err := authServer1.Token(map[string]any{"aud": tokenAudiences}) + require.NoError(t, err) + + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + URL: authServer2.JWKSURL(), + RefreshInterval: time.Second * 5, + Audiences: []string{"aud2"}, + }, + { + URL: authServer1.JWKSURL(), + RefreshInterval: time.Second * 5, + Audiences: []string{"aud1", "aud5"}, + }, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Operations with a token should succeed + header := http.Header{ + "Authorization": []string{"Bearer " + token}, + } + res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + require.Equal(t, JwksName, res.Header.Get(xAuthenticatedByHeader)) + data, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, employeesExpectedData, string(data)) + }) + }) + + t.Run("with secret based configuration", func(t *testing.T) { + t.Parallel() + + matchingAud := "matchingAudience" + + secret := "example secret" + kid := "givenKID" + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + Secret: "secret", + Algorithm: string(jwkset.AlgHS256), + KeyId: "kid", + Audiences: []string{"aud3"}, + }, + { + Secret: secret, + Algorithm: string(jwkset.AlgHS256), + KeyId: kid, + Audiences: []string{matchingAud, "aud5"}, + }, + }) + + token := generateToken(t, kid, secret, jwt.SigningMethodHS256, jwt.MapClaims{ + "aud": matchingAud, + }) + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Operations with a token should succeed + header := http.Header{ + "Authorization": []string{"Bearer " + token}, + } + res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + require.Equal(t, JwksName, res.Header.Get(xAuthenticatedByHeader)) + data, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, employeesExpectedData, string(data)) + }) + }) }) t.Run("audience validation is ignored when expected aud is not provided", func(t *testing.T) { @@ -2831,6 +2930,178 @@ func TestAudienceValidation(t *testing.T) { require.Equal(t, employeesExpectedData, string(data)) }) }) + + t.Run("valid token with empty algorithm in JWKS", func(t *testing.T) { + t.Parallel() + + rsaCrypto, err := jwks.NewRSACrypto("", "", 2048) + require.NoError(t, err) + + authServer, err := jwks.NewServerWithCrypto(t, rsaCrypto) + require.NoError(t, err) + t.Cleanup(authServer.Close) + + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + URL: authServer.JWKSURL(), + RefreshInterval: time.Second * 5, + }, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, false)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Operations with a token should succeed + token, err := authServer.TokenWithOpts(nil, jwks.TokenOpts{ + AlgOverride: string(jwkset.AlgRS256), + }) + require.NoError(t, err) + header := http.Header{ + "Authorization": []string{"Bearer " + token}, + } + res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + require.Equal(t, JwksName, res.Header.Get(xAuthenticatedByHeader)) + data, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, employeesExpectedData, string(data)) + }) + }) + + t.Run("verify blocking invalid specified algorithm even though token is valid", func(t *testing.T) { + t.Parallel() + + rsaCrypto, err := jwks.NewRSACrypto("", "", 2048) + require.NoError(t, err) + + authServer, err := jwks.NewServerWithCrypto(t, rsaCrypto) + require.NoError(t, err) + t.Cleanup(authServer.Close) + + allowedAlgorithm := jwkset.AlgRS256 + + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + URL: authServer.JWKSURL(), + RefreshInterval: time.Second * 5, + AllowedAlgorithms: []string{string(allowedAlgorithm)}, + }, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, false)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Fail with RS512 + token2, err := authServer.TokenWithOpts(nil, jwks.TokenOpts{ + AlgOverride: string(jwkset.AlgRS512), + }) + require.NoError(t, err) + res2, err := xEnv.MakeRequest(http.MethodPost, "/graphql", http.Header{ + "Authorization": []string{"Bearer " + token2}, + }, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { + _ = res2.Body.Close() + }() + require.Equal(t, http.StatusUnauthorized, res2.StatusCode) + }) + }) + + t.Run("verify blocking invalid algorithm", func(t *testing.T) { + t.Parallel() + + rsaCrypto, err := jwks.NewRSACrypto("", "R4ND0M", 2048) + require.NoError(t, err) + + authServer, err := jwks.NewServerWithCrypto(t, rsaCrypto) + require.NoError(t, err) + t.Cleanup(authServer.Close) + + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + toJWKSConfig(authServer.JWKSURL(), time.Second*5), + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, true)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Manually craft a JWT with an unregistered/unknown alg value + hdr := map[string]any{"alg": "R4ND0M", "typ": "JWT", jwkset.HeaderKID: rsaCrypto.KID()} + pl := map[string]any{} + hBytes, err := json.Marshal(hdr) + require.NoError(t, err) + pBytes, err := json.Marshal(pl) + require.NoError(t, err) + signed := base64.RawURLEncoding.EncodeToString(hBytes) + "." + base64.RawURLEncoding.EncodeToString(pBytes) + ".bogus" + + header := http.Header{"Authorization": []string{"Bearer " + signed}} + res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer func() { _ = res.Body.Close() }() + require.Equal(t, http.StatusUnauthorized, res.StatusCode) + data, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.JSONEq(t, unauthorizedExpectedData, string(data)) + }) + }) + + t.Run("valid token for second entry with empty algorithm in JWKS", func(t *testing.T) { + t.Parallel() + + rsaCrypto, err := jwks.NewRSACrypto("", "", 2048) + require.NoError(t, err) + + authServer1, err := jwks.NewServerWithCrypto(t, rsaCrypto) + require.NoError(t, err) + t.Cleanup(authServer1.Close) + + authServer2, err := jwks.NewServerWithCrypto(t, rsaCrypto) + require.NoError(t, err) + t.Cleanup(authServer2.Close) + + authenticators := ConfigureAuthWithJwksConfig(t, []authentication.JWKSConfig{ + { + URL: authServer1.JWKSURL(), + RefreshInterval: time.Second * 5, + AllowedAlgorithms: []string{string(jwkset.AlgRS256)}, + }, + { + URL: authServer2.JWKSURL(), + RefreshInterval: time.Second * 5, + AllowedAlgorithms: []string{string(jwkset.AlgRS512)}, + }, + }) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, false)), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + // Operations with a token should succeed + token, err := authServer2.TokenWithOpts(nil, jwks.TokenOpts{ + AlgOverride: string(jwkset.AlgRS512), + }) + require.NoError(t, err) + header := http.Header{ + "Authorization": []string{"Bearer " + token}, + } + res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(employeesQuery)) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + require.Equal(t, JwksName, res.Header.Get(xAuthenticatedByHeader)) + data, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, employeesExpectedData, string(data)) + }) + }) } func toJWKSConfig(url string, refresh time.Duration, allowedAlgorithms ...string) authentication.JWKSConfig { diff --git a/router-tests/go.mod b/router-tests/go.mod index ffe6d65377..b49dec69d0 100644 --- a/router-tests/go.mod +++ b/router-tests/go.mod @@ -3,7 +3,7 @@ module github.com/wundergraph/cosmo/router-tests go 1.25 require ( - github.com/MicahParks/jwkset v0.9.0 + github.com/MicahParks/jwkset v0.11.0 github.com/buger/jsonparser v1.1.1 github.com/cloudflare/backoff v0.0.0-20240920015135-e46b80a3a7d0 github.com/golang-jwt/jwt/v5 v5.2.2 @@ -45,7 +45,7 @@ require ( connectrpc.com/connect v1.16.2 // indirect github.com/99designs/gqlgen v0.17.76 // indirect github.com/KimMachineGun/automemlimit v0.6.1 // indirect - github.com/MicahParks/keyfunc/v3 v3.3.5 // indirect + github.com/MicahParks/keyfunc/v3 v3.6.2 // indirect github.com/agnivade/levenshtein v1.2.1 // indirect github.com/andybalholm/brotli v1.1.0 // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect diff --git a/router-tests/go.sum b/router-tests/go.sum index c6ec48801f..65eaff9693 100644 --- a/router-tests/go.sum +++ b/router-tests/go.sum @@ -5,10 +5,10 @@ github.com/99designs/gqlgen v0.17.76/go.mod h1:miiU+PkAnTIDKMQ1BseUOIVeQHoiwYDZG github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/KimMachineGun/automemlimit v0.6.1 h1:ILa9j1onAAMadBsyyUJv5cack8Y1WT26yLj/V+ulKp8= github.com/KimMachineGun/automemlimit v0.6.1/go.mod h1:T7xYht7B8r6AG/AqFcUdc7fzd2bIdBKmepfP2S1svPY= -github.com/MicahParks/jwkset v0.9.0 h1:xDlGu6mZJdJ+mgAI4mIRqWm2p8Vrx0U98LMgRObw46M= -github.com/MicahParks/jwkset v0.9.0/go.mod h1:fVrj6TmG1aKlJEeceAz7JsXGTXEn72zP1px3us53JrA= -github.com/MicahParks/keyfunc/v3 v3.3.5 h1:7ceAJLUAldnoueHDNzF8Bx06oVcQ5CfJnYwNt1U3YYo= -github.com/MicahParks/keyfunc/v3 v3.3.5/go.mod h1:SdCCyMJn/bYqWDvARspC6nCT8Sk74MjuAY22C7dCST8= +github.com/MicahParks/jwkset v0.11.0 h1:yc0zG+jCvZpWgFDFmvs8/8jqqVBG9oyIbmBtmjOhoyQ= +github.com/MicahParks/jwkset v0.11.0/go.mod h1:U2oRhRaLgDCLjtpGL2GseNKGmZtLs/3O7p+OZaL5vo0= +github.com/MicahParks/keyfunc/v3 v3.6.2 h1:82rre60MKw4r117ew5/T4m1AphgkpCOYry0RPbFUY3w= +github.com/MicahParks/keyfunc/v3 v3.6.2/go.mod h1:z66bkCviwqfg2YUp+Jcc/xRE9IXLcMq6DrgV/+Htru0= github.com/PuerkitoBio/goquery v1.10.3 h1:pFYcNSqHxBD06Fpj/KsbStFRsgRATgnf3LeXiUkhzPo= github.com/PuerkitoBio/goquery v1.10.3/go.mod h1:tMUX0zDMHXYlAQk6p35XxQMqMweEKB7iK7iLNd4RH4Y= github.com/agnivade/levenshtein v1.2.1 h1:EHBY3UOn1gwdy/VbFwgo4cxecRznFk7fKWN1KOX7eoM= diff --git a/router-tests/jwks/jwks.go b/router-tests/jwks/jwks.go index 6b77a76812..9e0cabcc8e 100644 --- a/router-tests/jwks/jwks.go +++ b/router-tests/jwks/jwks.go @@ -36,13 +36,28 @@ func (s *Server) Close() { s.httpServer.Close() } +type TokenOpts struct { + AlgOverride string +} + func (s *Server) Token(claims map[string]any) (string, error) { + return s.TokenWithOpts(claims, TokenOpts{AlgOverride: ""}) +} + +func (s *Server) TokenWithOpts(claims map[string]any, tokenOpts TokenOpts) (string, error) { if len(s.providers) == 0 { return "", jwt.ErrInvalidKey } for kid, pr := range s.providers { - token := jwt.NewWithClaims(pr.SigningMethod(), jwt.MapClaims(claims)) + method := pr.SigningMethod() + if tokenOpts.AlgOverride != "" { + method = jwt.GetSigningMethod(tokenOpts.AlgOverride) + if method == nil { + return "", fmt.Errorf("unsupported signing method: %s", tokenOpts.AlgOverride) + } + } + token := jwt.NewWithClaims(method, jwt.MapClaims(claims)) token.Header[jwkset.HeaderKID] = kid return token.SignedString(pr.PrivateKey()) } diff --git a/router/go.mod b/router/go.mod index 180c9f51d0..3bf34dfab4 100644 --- a/router/go.mod +++ b/router/go.mod @@ -58,8 +58,8 @@ require ( require ( github.com/KimMachineGun/automemlimit v0.6.1 - github.com/MicahParks/jwkset v0.9.0 - github.com/MicahParks/keyfunc/v3 v3.3.5 + github.com/MicahParks/jwkset v0.11.0 + github.com/MicahParks/keyfunc/v3 v3.6.2 github.com/alicebob/miniredis/v2 v2.34.0 github.com/caarlos0/env/v11 v11.3.1 github.com/cep21/circuit/v4 v4.0.0 diff --git a/router/go.sum b/router/go.sum index 0263992f20..0bcf0aa4f0 100644 --- a/router/go.sum +++ b/router/go.sum @@ -5,10 +5,10 @@ github.com/99designs/gqlgen v0.17.49/go.mod h1:tC8YFVZMed81x7UJ7ORUwXF4Kn6SXuucF github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/KimMachineGun/automemlimit v0.6.1 h1:ILa9j1onAAMadBsyyUJv5cack8Y1WT26yLj/V+ulKp8= github.com/KimMachineGun/automemlimit v0.6.1/go.mod h1:T7xYht7B8r6AG/AqFcUdc7fzd2bIdBKmepfP2S1svPY= -github.com/MicahParks/jwkset v0.9.0 h1:xDlGu6mZJdJ+mgAI4mIRqWm2p8Vrx0U98LMgRObw46M= -github.com/MicahParks/jwkset v0.9.0/go.mod h1:fVrj6TmG1aKlJEeceAz7JsXGTXEn72zP1px3us53JrA= -github.com/MicahParks/keyfunc/v3 v3.3.5 h1:7ceAJLUAldnoueHDNzF8Bx06oVcQ5CfJnYwNt1U3YYo= -github.com/MicahParks/keyfunc/v3 v3.3.5/go.mod h1:SdCCyMJn/bYqWDvARspC6nCT8Sk74MjuAY22C7dCST8= +github.com/MicahParks/jwkset v0.11.0 h1:yc0zG+jCvZpWgFDFmvs8/8jqqVBG9oyIbmBtmjOhoyQ= +github.com/MicahParks/jwkset v0.11.0/go.mod h1:U2oRhRaLgDCLjtpGL2GseNKGmZtLs/3O7p+OZaL5vo0= +github.com/MicahParks/keyfunc/v3 v3.6.2 h1:82rre60MKw4r117ew5/T4m1AphgkpCOYry0RPbFUY3w= +github.com/MicahParks/keyfunc/v3 v3.6.2/go.mod h1:z66bkCviwqfg2YUp+Jcc/xRE9IXLcMq6DrgV/+Htru0= github.com/agnivade/levenshtein v1.1.1 h1:QY8M92nrzkmr798gCo3kmMyqXFzdQVpxLlGPRBij0P8= github.com/agnivade/levenshtein v1.1.1/go.mod h1:veldBMzWxcCG2ZvUTKD2kJNRdCk5hVbJomOvKkmgYbo= github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 h1:uvdUDbHQHO85qeSydJtItA4T55Pw6BtAejd0APRJOCE= diff --git a/router/pkg/authentication/jwks_token_decoder.go b/router/pkg/authentication/jwks_token_decoder.go index c6ae6c794e..1685c2285a 100644 --- a/router/pkg/authentication/jwks_token_decoder.go +++ b/router/pkg/authentication/jwks_token_decoder.go @@ -5,12 +5,14 @@ import ( "errors" "fmt" "net/http" + "slices" "time" "golang.org/x/time/rate" "github.com/MicahParks/jwkset" "github.com/MicahParks/keyfunc/v3" + "github.com/golang-jwt/jwt/v5" "github.com/wundergraph/cosmo/router/internal/httpclient" "go.uber.org/zap" @@ -60,20 +62,27 @@ type RefreshUnknownKIDConfig struct { MaxWait time.Duration } -type audKey struct { +type configKey struct { kid string url string } type audienceSet map[string]struct{} +type keyFuncEntry struct { + jwks keyfunc.Keyfunc + aud audienceSet + allowedAlgorithms []string +} + func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKSConfig) (TokenDecoder, error) { - audiencesMap := make(map[audKey]audienceSet, len(configs)) - keyFuncMap := make(map[audKey]keyfunc.Keyfunc, len(configs)) + // Audience map is used to validate duplicate configs + audiencesMap := make(map[configKey]audienceSet, len(configs)) + entries := make([]keyFuncEntry, 0, len(configs)) for _, c := range configs { if c.URL != "" { - key := audKey{url: c.URL} + key := configKey{url: c.URL} if _, ok := audiencesMap[key]; ok { return nil, fmt.Errorf("duplicate JWK URL found: %s", c.URL) } @@ -90,7 +99,7 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS l.Error("Failed to refresh HTTP JWK Set from remote HTTP resource.", zap.Error(err)) }, RefreshInterval: c.RefreshInterval, - Storage: NewValidationStore(logger, nil, c.AllowedAlgorithms), + Storage: jwkset.NewMemoryStorage(), } store, err := jwkset.NewStorageFromHTTP(c.URL, jwksetHTTPStorageOptions) @@ -117,14 +126,17 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS if err != nil { return nil, err } - keyFuncMap[key] = jwks + entries = append(entries, keyFuncEntry{ + jwks: jwks, + aud: audiencesMap[key], + allowedAlgorithms: c.AllowedAlgorithms, + }) } else if c.Secret != "" { - key := audKey{kid: c.KeyId} + key := configKey{kid: c.KeyId} if _, ok := audiencesMap[key]; ok { return nil, fmt.Errorf("duplicate JWK keyid specified found: %s", c.KeyId) } - given := jwkset.NewMemoryStorage() marshalOptions := jwkset.JWKMarshalOptions{ @@ -168,29 +180,55 @@ func NewJwksTokenDecoder(ctx context.Context, logger *zap.Logger, configs []JWKS if err != nil { return nil, err } - keyFuncMap[key] = jwks + entries = append(entries, keyFuncEntry{ + jwks: jwks, + aud: audiencesMap[key], + }) } } keyFuncWrapper := jwt.Keyfunc(func(token *jwt.Token) (any, error) { var errJoin error - for key, keyFunc := range keyFuncMap { - pub, err := keyFunc.Keyfunc(token) - if err != nil { - errJoin = errors.Join(errJoin, err) - continue - } - - expectedAudiences := audiencesMap[key] - if len(expectedAudiences) > 0 { + for _, entry := range entries { + if len(entry.aud) > 0 { tokenAudiences, err := token.Claims.GetAudience() if err != nil { - return nil, fmt.Errorf("could not get audiences from token claims: %w", err) + errJoin = errors.Join(errJoin, fmt.Errorf("could not get audiences from token claims: %w", err)) + continue + } + if !hasAudience(tokenAudiences, entry.aud) { + errJoin = errors.Join(errJoin, errUnacceptableAud) + continue } - if !hasAudience(tokenAudiences, expectedAudiences) { - return nil, errUnacceptableAud + } + + // When an algorithm is actually provided in the jwks the current keyfunc will validate the + // jwks algorithm with it. But when no algorithm is provided (alg: none or missing alg) + // the default keyfunc will not validate the algorithm as it has nothing to cross check. + if len(entry.allowedAlgorithms) > 0 { + algInter, ok := token.Header["alg"] + if !ok { + errJoin = errors.Join(errJoin, fmt.Errorf("%w: could not find alg in JWT header", keyfunc.ErrKeyfunc)) + continue + } + alg, ok := algInter.(string) + if !ok { + errJoin = errors.Join(errJoin, fmt.Errorf(`%w: the JWT header did not contain the "alg" parameter, which is required by RFC 7515 section 4.1.1`, keyfunc.ErrKeyfunc)) + continue + } + + // This is a custom validation different from the original keyfunc.Keyfunc + if !slices.Contains(entry.allowedAlgorithms, alg) { + errJoin = errors.Join(errJoin, fmt.Errorf("%w: could not find alg %s in allow list", keyfunc.ErrKeyfunc, alg)) + continue } } + + pub, err := entry.jwks.Keyfunc(token) + if err != nil { + errJoin = errors.Join(errJoin, err) + continue + } return pub, nil } diff --git a/router/pkg/authentication/validation_store.go b/router/pkg/authentication/validation_store.go deleted file mode 100644 index d447c26710..0000000000 --- a/router/pkg/authentication/validation_store.go +++ /dev/null @@ -1,151 +0,0 @@ -package authentication - -import ( - "context" - "encoding/json" - "fmt" - - "github.com/MicahParks/jwkset" - "go.uber.org/zap" -) - -var _ jwkset.Storage = (*validationStore)(nil) - -type validationStore struct { - logger *zap.Logger - algs map[string]struct{} - inner jwkset.Storage -} - -var supportedAlgorithms = map[string]struct{}{ - "HS256": {}, - "HS384": {}, - "HS512": {}, - "RS256": {}, - "RS384": {}, - "RS512": {}, - "PS256": {}, - "PS384": {}, - "PS512": {}, - "ES256": {}, - "ES384": {}, - "ES512": {}, - "EdDSA": {}, -} - -func NewValidationStore(logger *zap.Logger, inner jwkset.Storage, algs []string) jwkset.Storage { - if inner == nil { - inner = jwkset.NewMemoryStorage() - } - - if logger == nil { - logger = zap.NewNop() - } - - algSet := make(map[string]struct{}, len(algs)) - - store := &validationStore{ - logger: logger, - inner: inner, - algs: supportedAlgorithms, - } - - if len(algs) == 0 { - return store - } - - for _, alg := range algs { - if _, ok := supportedAlgorithms[alg]; !ok { - logger.Warn("Unsupported algorithm", zap.String("algorithm", alg)) - continue - } - algSet[alg] = struct{}{} - } - - store.algs = algSet - return store -} - -func (v *validationStore) KeyDelete(ctx context.Context, keyID string) (ok bool, err error) { - return v.inner.KeyDelete(ctx, keyID) -} - -func (v *validationStore) KeyRead(ctx context.Context, keyID string) (jwkset.JWK, error) { - key, err := v.inner.KeyRead(ctx, keyID) - if err != nil { - return key, err - } - - m := key.Marshal() - if _, ok := v.algs[m.ALG.String()]; ok { - return key, nil - } - - return jwkset.JWK{}, fmt.Errorf("key with ID %q has an unsupported algorithm %s", keyID, m.ALG.String()) -} - -func (v *validationStore) KeyReadAll(ctx context.Context) ([]jwkset.JWK, error) { - keys, err := v.inner.KeyReadAll(ctx) - if err != nil { - return nil, err - } - - filter := make([]jwkset.JWK, 0, len(keys)) - - for _, k := range keys { - m := k.Marshal() - if _, ok := v.algs[m.ALG.String()]; ok { - filter = append(filter, k) - } - } - - return filter, nil -} - -func (v *validationStore) KeyReplaceAll(ctx context.Context, given []jwkset.JWK) error { - filtered := make([]jwkset.JWK, 0) - for _, k := range given { - m := k.Marshal() - if _, ok := v.algs[m.ALG.String()]; ok { - filtered = append(filtered, k) - } - } - return v.inner.KeyReplaceAll(ctx, filtered) -} - -func (v *validationStore) KeyWrite(ctx context.Context, jwk jwkset.JWK) error { - jwkMarshal := jwk.Marshal() - if _, ok := v.algs[jwkMarshal.ALG.String()]; !ok { - // We should not return an error here. If JWKS are configured for multiple applications, we should only add the - // supported keys to the token decoder store and not prevent the refresh entirely. - // In case we are receiving a key with an unsupported algorithm we log a warning instead. - v.logger.Warn("Skipping key with unsupported algorithm", zap.String("keyID", jwkMarshal.KID), zap.String("algorithm", jwkMarshal.ALG.String())) - return nil - } - - return v.inner.KeyWrite(ctx, jwk) -} - -func (v *validationStore) JSON(ctx context.Context) (json.RawMessage, error) { - return v.inner.JSON(ctx) -} - -func (v *validationStore) JSONPublic(ctx context.Context) (json.RawMessage, error) { - return v.inner.JSONPublic(ctx) -} - -func (v *validationStore) JSONPrivate(ctx context.Context) (json.RawMessage, error) { - return v.inner.JSONPrivate(ctx) -} - -func (v *validationStore) JSONWithOptions(ctx context.Context, marshalOptions jwkset.JWKMarshalOptions, validationOptions jwkset.JWKValidateOptions) (json.RawMessage, error) { - return v.inner.JSONWithOptions(ctx, marshalOptions, validationOptions) -} - -func (v *validationStore) Marshal(ctx context.Context) (jwkset.JWKSMarshal, error) { - return v.inner.Marshal(ctx) -} - -func (v *validationStore) MarshalWithOptions(ctx context.Context, marshalOptions jwkset.JWKMarshalOptions, validationOptions jwkset.JWKValidateOptions) (jwkset.JWKSMarshal, error) { - return v.inner.MarshalWithOptions(ctx, marshalOptions, validationOptions) -} diff --git a/router/pkg/config/config.schema.json b/router/pkg/config/config.schema.json index d6f4b2b466..f649a952b6 100644 --- a/router/pkg/config/config.schema.json +++ b/router/pkg/config/config.schema.json @@ -1781,6 +1781,9 @@ }, { "required": ["refresh_interval"] + }, + { + "required": ["refresh_unknown_kid"] } ] }