Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions keyfunc.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"log/slog"
"slices"
"time"

"github.com/MicahParks/jwkset"
Expand All @@ -33,6 +34,9 @@ type Options struct {
Ctx context.Context
Storage jwkset.Storage
UseWhitelist []jwkset.USE

// Custom Non Base on original keyfunc
AllowedAlgorithms []string
}

// Override is used to change specific default behaviors.
Expand All @@ -53,9 +57,10 @@ type Override struct {
}

type keyfunc struct {
ctx context.Context
storage jwkset.Storage
useWhitelist []jwkset.USE
ctx context.Context
storage jwkset.Storage
useWhitelist []jwkset.USE
allowedAlgorithms []string
}

// New creates a new Keyfunc.
Expand All @@ -68,9 +73,10 @@ func New(options Options) (Keyfunc, error) {
return nil, fmt.Errorf("%w: no JWK Set storage given in options", ErrKeyfunc)
}
k := keyfunc{
ctx: ctx,
storage: options.Storage,
useWhitelist: options.UseWhitelist,
ctx: ctx,
storage: options.Storage,
useWhitelist: options.UseWhitelist,
allowedAlgorithms: options.AllowedAlgorithms,
}
return k, nil
}
Expand Down Expand Up @@ -223,6 +229,16 @@ func (k keyfunc) KeyfuncCtx(ctx context.Context) jwt.Keyfunc {
return nil, fmt.Errorf(`%w: the JWT header did not contain the "alg" parameter, which is required by RFC 7515 section 4.1.1`, ErrKeyfunc)
}

// When an algorithm is actually provided in the jwks the current keyfunc will validate the
// jwks algorithm with it. But when no algorithm is provided (alg: none or missing alg)
// the default keyfunc will not validate the algorithm as it has nothing to cross check.
if len(k.allowedAlgorithms) > 0 {
// This is a custom validation different from the original keyfunc.Keyfunc
if !slices.Contains(k.allowedAlgorithms, alg) {
return nil, fmt.Errorf("%w: could not find alg %s in allow list", ErrKeyfunc, alg)
}
}

jwk, err := k.storage.KeyRead(ctx, kid)
if err != nil {
return nil, fmt.Errorf("%w: could not read JWK from storage", errors.Join(err, ErrKeyfunc))
Expand Down