Skip to content

Commit

Permalink
Support almost all base64 options
Browse files Browse the repository at this point in the history
  • Loading branch information
oxisto committed Jul 4, 2024
1 parent f0fa303 commit 7684d3e
Show file tree
Hide file tree
Showing 8 changed files with 49 additions and 34 deletions.
21 changes: 11 additions & 10 deletions encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,21 @@ package jwt

import "io"

// Base64Encoder is an interface that allows to implement custom Base64 encoding
// algorithms.
type Base64EncodeFunc func(src []byte) string
type Base64Encoding interface {
EncodeToString(src []byte) string
DecodeString(s string) ([]byte, error)
}

// Base64Decoder is an interface that allows to implement custom Base64 decoding
// algorithms.
type Base64DecodeFunc func(s string) ([]byte, error)
type Stricter[T Base64Encoding] interface {
Strict() T
}

// JSONEncoder is an interface that allows to implement custom JSON encoding
// algorithms.
// JSONMarshalFunc is an function type that allows to implement custom JSON
// encoding algorithms.
type JSONMarshalFunc func(v any) ([]byte, error)

// JSONUnmarshal is an interface that allows to implement custom JSON unmarshal
// algorithms.
// JSONUnmarshalFunc is an function type that allows to implement custom JSON
// unmarshal algorithms.
type JSONUnmarshalFunc func(data []byte, v any) error

type JSONDecoder interface {
Expand Down
1 change: 1 addition & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ var (
ErrTokenInvalidId = errors.New("token has invalid id")
ErrTokenInvalidClaims = errors.New("token has invalid claims")
ErrInvalidType = errors.New("invalid type for claim")
ErrUnsupported = errors.New("operation is unsupported")
)

// joinedError is an error type that works similar to what [errors.Join]
Expand Down
32 changes: 22 additions & 10 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,11 @@ type Parser struct {
type decoders struct {
jsonUnmarshal JSONUnmarshalFunc
jsonNewDecoder JSONNewDecoderFunc[JSONDecoder]
base64Decode Base64DecodeFunc

// This field is disabled when using a custom base64 encoder.
decodeStrict bool
rawUrlBase64Encoding Base64Encoding
urlBase64Encoding Base64Encoding

// This field is disabled when using a custom base64 encoder.
decodeStrict bool
decodePaddingAllowed bool
}

Expand Down Expand Up @@ -227,22 +226,35 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke
// take into account whether the [Parser] is configured with additional options,
// such as [WithStrictDecoding] or [WithPaddingAllowed].
func (p *Parser) DecodeSegment(seg string) ([]byte, error) {
if p.base64Decode != nil {
return p.base64Decode(seg)
var encoding Base64Encoding
if p.rawUrlBase64Encoding != nil {
encoding = p.rawUrlBase64Encoding
} else {
encoding = base64.RawURLEncoding
}

encoding := base64.RawURLEncoding

if p.decodePaddingAllowed {
if l := len(seg) % 4; l > 0 {
seg += strings.Repeat("=", 4-l)
}
encoding = base64.URLEncoding

if p.urlBase64Encoding != nil {
encoding = p.urlBase64Encoding
} else {
encoding = base64.URLEncoding
}
}

if p.decodeStrict {
encoding = encoding.Strict()
// For now we can only support the standard library here because of the
// current state of the type parameter system
stricter, ok := encoding.(Stricter[*base64.Encoding])
if !ok {
return nil, newError("strict mode is only supported in encoding/base64", ErrUnsupported)
}
encoding = stricter.Strict()
}

return encoding.DecodeString(seg)
}

Expand Down
7 changes: 4 additions & 3 deletions parser_option.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,10 @@ func WithJSONDecoder[T JSONDecoder](f JSONUnmarshalFunc, f2 JSONNewDecoderFunc[T
}
}

// WithBase64Decoder supports a custom [Base64Decoder] to use in parsing the JWT.
func WithBase64Decoder(f Base64DecodeFunc) ParserOption {
// WithBase64Decoder supports a custom [Base64Encoding] to use in parsing the JWT.
func WithBase64Decoder(rawURL Base64Encoding, url Base64Encoding) ParserOption {
return func(p *Parser) {
p.base64Decode = f
p.rawUrlBase64Encoding = rawURL
p.urlBase64Encoding = url
}
}
2 changes: 1 addition & 1 deletion parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ var jwtTestData = []struct {
jwt.MapClaims{"foo": "bar"},
true,
nil,
jwt.NewParser(jwt.WithBase64Decoder(base64.RawURLEncoding.DecodeString)),
jwt.NewParser(jwt.WithBase64Decoder(base64.RawURLEncoding, base64.URLEncoding)),
jwt.SigningMethodRS256,
},
{
Expand Down
14 changes: 7 additions & 7 deletions token.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ type Token struct {
}

type encoders struct {
jsonMarshal JSONMarshalFunc // jsonEncoder is the custom json encoder/decoder
base64Encode Base64EncodeFunc // base64Encoder is the custom base64 encoder/decoder
jsonMarshal JSONMarshalFunc // jsonEncoder is the custom json encoder/decoder
base64Encoding Base64Encoding // base64Encoder is the custom base64 encoding
}

// New creates a new [Token] with the specified signing method and an empty map
Expand Down Expand Up @@ -114,12 +114,12 @@ func (t *Token) SigningString() (string, error) {
// [TokenOption]. Therefore, this function exists as a method of [Token], rather
// than a global function.
func (t *Token) EncodeSegment(seg []byte) string {
var enc Base64EncodeFunc
if t.base64Encode != nil {
enc = t.base64Encode
var enc Base64Encoding
if t.base64Encoding != nil {
enc = t.base64Encoding
} else {
enc = base64.RawURLEncoding.EncodeToString
enc = base64.RawURLEncoding
}

return enc(seg)
return enc.EncodeToString(seg)
}
4 changes: 2 additions & 2 deletions token_option.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ func WithJSONEncoder(f JSONMarshalFunc) TokenOption {
}
}

func WithBase64Encoder(f Base64EncodeFunc) TokenOption {
func WithBase64Encoder(enc Base64Encoding) TokenOption {
return func(token *Token) {
token.base64Encode = f
token.base64Encoding = enc
}
}
2 changes: 1 addition & 1 deletion token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func TestToken_SigningString(t1 *testing.T) {
Valid: false,
Options: []jwt.TokenOption{
jwt.WithJSONEncoder(json.Marshal),
jwt.WithBase64Encoder(base64.StdEncoding.EncodeToString),
jwt.WithBase64Encoder(base64.StdEncoding),
},
},
want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30",
Expand Down

0 comments on commit 7684d3e

Please sign in to comment.