diff --git a/keyfunc.go b/keyfunc.go index ac63f9b..c33d6ed 100644 --- a/keyfunc.go +++ b/keyfunc.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "log/slog" + "slices" "time" "github.com/MicahParks/jwkset" @@ -33,6 +34,9 @@ type Options struct { Ctx context.Context Storage jwkset.Storage UseWhitelist []jwkset.USE + + // Custom Non Base on original keyfunc + AllowedAlgorithms []string } // Override is used to change specific default behaviors. @@ -53,9 +57,10 @@ type Override struct { } type keyfunc struct { - ctx context.Context - storage jwkset.Storage - useWhitelist []jwkset.USE + ctx context.Context + storage jwkset.Storage + useWhitelist []jwkset.USE + allowedAlgorithms []string } // New creates a new Keyfunc. @@ -68,9 +73,10 @@ func New(options Options) (Keyfunc, error) { return nil, fmt.Errorf("%w: no JWK Set storage given in options", ErrKeyfunc) } k := keyfunc{ - ctx: ctx, - storage: options.Storage, - useWhitelist: options.UseWhitelist, + ctx: ctx, + storage: options.Storage, + useWhitelist: options.UseWhitelist, + allowedAlgorithms: options.AllowedAlgorithms, } return k, nil } @@ -223,6 +229,16 @@ func (k keyfunc) KeyfuncCtx(ctx context.Context) jwt.Keyfunc { return nil, fmt.Errorf(`%w: the JWT header did not contain the "alg" parameter, which is required by RFC 7515 section 4.1.1`, ErrKeyfunc) } + // When an algorithm is actually provided in the jwks the current keyfunc will validate the + // jwks algorithm with it. But when no algorithm is provided (alg: none or missing alg) + // the default keyfunc will not validate the algorithm as it has nothing to cross check. + if len(k.allowedAlgorithms) > 0 { + // This is a custom validation different from the original keyfunc.Keyfunc + if !slices.Contains(k.allowedAlgorithms, alg) { + return nil, fmt.Errorf("%w: could not find alg %s in allow list", ErrKeyfunc, alg) + } + } + jwk, err := k.storage.KeyRead(ctx, kid) if err != nil { return nil, fmt.Errorf("%w: could not read JWK from storage", errors.Join(err, ErrKeyfunc))