diff --git a/README.md b/README.md index 43deddb9..af31e8eb 100644 --- a/README.md +++ b/README.md @@ -3,202 +3,7 @@ **WARNING** This `v2` branch is not production ready - use at your own risk. -**NOTE:** We released this version using a fork of jwt-go in order to address a security vulnerability. Due to jwt-go not being actively maintained we will be looking to switch to a more actively maintained package in the near future. - -A middleware that will check that a [JWT](http://jwt.io/) is sent on the `Authorization` header and will then set the content of the JWT into the `user` variable of the request. - -This module lets you authenticate HTTP requests using JWT tokens in your Go Programming Language applications. JWTs are typically used to protect API endpoints, and are often issued using OpenID Connect. - -## Key Features - -* Ability to **check the `Authorization` header for a JWT** -* **Decode the JWT** and set the content of it to the request context - -## Installing - -````bash -go get github.com/auth0/go-jwt-middleware -```` - -## Using it - -You can use `jwtmiddleware` with default `net/http` as follows. - -````go -// main.go -package main - -import ( - "fmt" - "net/http" - - "github.com/auth0/go-jwt-middleware" - "github.com/form3tech-oss/jwt-go" -) - -var myHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - user := r.Context().Value("user") - fmt.Fprintf(w, "This is an authenticated request") - fmt.Fprintf(w, "Claim content:\n") - for k, v := range user.(*jwt.Token).Claims.(jwt.MapClaims) { - fmt.Fprintf(w, "%s :\t%#v\n", k, v) - } -}) - -func main() { - jwtMiddleware := jwtmiddleware.New(jwtmiddleware.Options{ - ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) { - return []byte("My Secret"), nil - }, - // When set, the middleware verifies that tokens are signed with the specific signing algorithm - // If the signing method is not constant the ValidationKeyGetter callback can be used to implement additional checks - // Important to avoid security issues described here: https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/ - SigningMethod: jwt.SigningMethodHS256, - }) - - app := jwtMiddleware.Handler(myHandler) - http.ListenAndServe("0.0.0.0:3000", app) -} -```` - -You can also use it with Negroni as follows: - -````go -// main.go -package main - -import ( - "encoding/json" - "net/http" - - "github.com/auth0/go-jwt-middleware" - "github.com/form3tech-oss/jwt-go" - "github.com/gorilla/mux" - "github.com/urfave/negroni" -) - -func main() { - StartServer() -} - -func StartServer() { - r := mux.NewRouter() - - jwtMiddleware := jwtmiddleware.New(jwtmiddleware.Options{ - ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) { - return []byte("My Secret"), nil - }, - SigningMethod: jwt.SigningMethodHS256, - }) - - r.HandleFunc("/ping", PingHandler) - r.Handle("/secured/ping", negroni.New( - negroni.HandlerFunc(jwtMiddleware.HandlerWithNext), - negroni.Wrap(http.HandlerFunc(SecuredPingHandler)), - )) - http.Handle("/", r) - http.ListenAndServe(":3001", nil) -} - -type Response struct { - Text string `json:"text"` -} - -func respondJSON(text string, w http.ResponseWriter) { - response := Response{text} - - jsonResponse, err := json.Marshal(response) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - w.Header().Set("Content-Type", "application/json") - w.Write(jsonResponse) -} - -func PingHandler(w http.ResponseWriter, r *http.Request) { - respondJSON("All good. You don't need to be authenticated to call this", w) -} - -func SecuredPingHandler(w http.ResponseWriter, r *http.Request) { - respondJSON("All good. You only get this message if you're authenticated", w) -} -```` - -## Options - -````go -// Options is a struct for specifying configuration options for the middleware. -type Options struct { - // The function that will return the Key to validate the JWT. - // It can be either a shared secret or a public key. - // Default value: nil - ValidationKeyGetter jwt.Keyfunc - // The name of the property in the request where the user information - // from the JWT will be stored. - // Default value: "user" - UserProperty string - // The function that will be called when there's an error validating the token - // Default value: - ErrorHandler errorHandler - // A boolean indicating if the credentials are required or not - // Default value: false - CredentialsOptional bool - // A function that extracts the token from the request - // Default: FromAuthHeader (i.e., from Authorization header as bearer token) - Extractor TokenExtractor - // Debug flag turns on debugging output - // Default: false - Debug bool - // When set, all requests with the OPTIONS method will use authentication - // Default: false - EnableAuthOnOptions bool - // When set, the middelware verifies that tokens are signed with the specific signing algorithm - // If the signing method is not constant the ValidationKeyGetter callback can be used to implement additional checks - // Important to avoid security issues described here: https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/ - // Default: nil - SigningMethod jwt.SigningMethod -} -```` - -### Token Extraction - -The default value for the `Extractor` option is the `FromAuthHeader` -function which assumes that the JWT will be provided as a bearer token -in an `Authorization` header, i.e., - -``` -Authorization: bearer {token} -``` - -To extract the token from a query string parameter, you can use the -`FromParameter` function, e.g., - -```go -jwtmiddleware.New(jwtmiddleware.Options{ - Extractor: jwtmiddleware.FromParameter("auth_code"), -}) -``` - -In this case, the `FromParameter` function will look for a JWT in the -`auth_code` query parameter. - -Or, if you want to allow both, you can use the `FromFirst` function to -try and extract the token first in one way and then in one or more -other ways, e.g., - -```go -jwtmiddleware.New(jwtmiddleware.Options{ - Extractor: jwtmiddleware.FromFirst(jwtmiddleware.FromAuthHeader, - jwtmiddleware.FromParameter("auth_code")), -}) -``` - -## Examples - -You can check out working examples in the [examples folder](https://github.com/auth0/go-jwt-middleware/tree/master/examples) - +TODO: update this README in the `v2` branch. We're waiting so as not to hold everything up in the testing branch. Also some of the default validation logic needs to be added here. ## What is Auth0? diff --git a/examples/http-example/main.go b/examples/http-example/main.go index 79724ae8..a1a18efc 100644 --- a/examples/http-example/main.go +++ b/examples/http-example/main.go @@ -8,6 +8,26 @@ import ( "github.com/form3tech-oss/jwt-go" ) +// TODO: replace this with default validate token func once it is merged in +func REPLACE_ValidateToken(token string) (interface{}, error) { + // Now parse the token + parsedToken, err := jwt.Parse(token, func(token *jwt.Token) (interface{}, error) { + return []byte("My Secret"), nil + }) + + // Check if there was an error in parsing... + if err != nil { + return nil, err + } + + // Check if the parsed token is valid... + if !parsedToken.Valid { + return nil, jwtmiddleware.ErrJWTInvalid + } + + return parsedToken, nil +} + var myHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { user := r.Context().Value("user") fmt.Fprintf(w, "This is an authenticated request") @@ -18,16 +38,7 @@ var myHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { }) func main() { - jwtMiddleware := jwtmiddleware.New(jwtmiddleware.Options{ - ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) { - return []byte("My Secret"), nil - }, - // When set, the middleware verifies that tokens are signed with the specific signing algorithm - // If the signing method is not constant the ValidationKeyGetter callback can be used to implement additional checks - // Important to avoid security issues described here: https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/ - SigningMethod: jwt.SigningMethodHS256, - }) + jwtMiddleware := jwtmiddleware.New(jwtmiddleware.WithValidateToken(REPLACE_ValidateToken)) - app := jwtMiddleware.Handler(myHandler) - http.ListenAndServe("0.0.0.0:3000", app) + http.ListenAndServe("0.0.0.0:3000", jwtMiddleware.CheckJWT(myHandler)) } diff --git a/go.mod b/go.mod index 5ea135e5..cb6ed17f 100644 --- a/go.mod +++ b/go.mod @@ -3,12 +3,6 @@ module github.com/auth0/go-jwt-middleware go 1.14 require ( - github.com/codegangsta/inject v0.0.0-20150114235600-33e0aa1cb7c0 // indirect github.com/form3tech-oss/jwt-go v3.2.2+incompatible - github.com/go-martini/martini v0.0.0-20170121215854-22fa46961aab - github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00 // indirect - github.com/gorilla/mux v1.7.4 - github.com/smartystreets/assertions v1.1.0 // indirect - github.com/smartystreets/goconvey v1.6.4 - github.com/urfave/negroni v1.0.0 + github.com/google/go-cmp v0.5.5 ) diff --git a/go.sum b/go.sum index 736f537d..11890ce3 100644 --- a/go.sum +++ b/go.sum @@ -1,27 +1,6 @@ -github.com/codegangsta/inject v0.0.0-20150114235600-33e0aa1cb7c0 h1:sDMmm+q/3+BukdIpxwO365v/Rbspp2Nt5XntgQRXq8Q= -github.com/codegangsta/inject v0.0.0-20150114235600-33e0aa1cb7c0/go.mod h1:4Zcjuz89kmFXt9morQgcfYZAYZ5n8WHjt81YYWIwtTM= github.com/form3tech-oss/jwt-go v3.2.2+incompatible h1:TcekIExNqud5crz4xD2pavyTgWiPvpYe4Xau31I0PRk= github.com/form3tech-oss/jwt-go v3.2.2+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k= -github.com/go-martini/martini v0.0.0-20170121215854-22fa46961aab h1:xveKWz2iaueeTaUgdetzel+U7exyigDYBryyVfV/rZk= -github.com/go-martini/martini v0.0.0-20170121215854-22fa46961aab/go.mod h1:/P9AEU963A2AYjv4d1V5eVL1CQbEJq6aCNHDDjibzu8= -github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8= -github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= -github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00 h1:l5lAOZEym3oK3SQ2HBHWsJUfbNBiTXJDeW2QDxw9AQ0= -github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= -github.com/gorilla/mux v1.7.4 h1:VuZ8uybHlWmqV03+zRzdwKL4tUnIp1MAQtp1mIFE1bc= -github.com/gorilla/mux v1.7.4/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= -github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= -github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= -github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM= -github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= -github.com/smartystreets/assertions v1.1.0 h1:MkTeG1DMwsrdH7QtLXy5W+fUxWq+vmb6cLmyJ7aRtF0= -github.com/smartystreets/assertions v1.1.0/go.mod h1:tcbTF8ujkAEcZ8TElKY+i30BzYlVhC/LOxJk7iOWnoo= -github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s= -github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= -github.com/urfave/negroni v1.0.0 h1:kIimOitoypq34K7TG7DUaJ9kq/N4Ofuwi1sjz0KipXc= -github.com/urfave/negroni v1.0.0/go.mod h1:Meg73S6kFm/4PpbYdq35yYWoCZ9mS/YSx+lKnmiohz4= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/jwtmiddleware.go b/jwtmiddleware.go index 16801bf6..9c0a0ca2 100644 --- a/jwtmiddleware.go +++ b/jwtmiddleware.go @@ -4,13 +4,51 @@ import ( "context" "errors" "fmt" - "log" "net/http" "strings" ) -// A function called whenever an error is encountered -type errorHandler func(w http.ResponseWriter, r *http.Request, err string) +var ( + ErrJWTMissing = errors.New("jwt missing") + ErrJWTInvalid = errors.New("jwt invalid") +) + +// ContextKey is the key used in the request context where the information +// from a validated JWT will be stored. +type ContextKey struct{} + +// invalidError handles wrapping a JWT validation error with the concrete error +// ErrJWTInvalid. We do not expose this publicly because the interface methods +// of Is and Unwrap should give the user all they need. +type invalidError struct { + details error +} + +// Is allows the error to support equality to ErrJWTInvalid. +func (e *invalidError) Is(target error) bool { + return target == ErrJWTInvalid +} + +func (e *invalidError) Error() string { + return fmt.Sprintf("%s: %s", ErrJWTInvalid, e.details) +} + +// Unwrap allows the error to support equality to the underlying error and not +// just ErrJWTInvalid. +func (e *invalidError) Unwrap() error { + return e.details +} + +// ErrorHandler is a handler which is called when an error occurs in the +// middleware. Among some general errors, this handler also determines the +// response of the middleware when a token is not found or is invalid. The err +// can be checked to be ErrJWTMissing or ErrJWTInvalid for specific cases. The +// default handler will return a status code of 400 for ErrJWTMissing, 401 for +// ErrJWTInvalid, and 500 for all other errors. If you implement your own +// ErrorHandler you MUST take into consideration the error types as not +// properly responding to them or having a poorly implemented handler could +// result in the middleware not functioning as intended. +type ErrorHandler func(w http.ResponseWriter, r *http.Request, err error) // TokenExtractor is a function that takes a request as input and returns // either a token or an error. An error should only be returned if an attempt @@ -26,107 +64,101 @@ type TokenExtractor func(r *http.Request) (string, error) // happen. In the default implementation we can add safe defaults for those. type ValidateToken func(string) (interface{}, error) -// Options is a struct for specifying configuration options for the middleware. -type Options struct { - // Validate handles validating a token. - Validate ValidateToken - // The name of the property in the request where the user information - // from the JWT will be stored. - // Default value: "user" - UserProperty string - // The function that will be called when there are errors in the - // middleware. - // Default value: OnError - ErrorHandler errorHandler - // A boolean indicating if the credentials are required or not - // Default value: false - CredentialsOptional bool - // A function that extracts the token from the request - // Default: FromAuthHeader (i.e., from Authorization header as bearer token) - Extractor TokenExtractor - // Debug flag turns on debugging output - // Default: false - Debug bool - // When set, all requests with the OPTIONS method will use authentication - // Default: false - EnableAuthOnOptions bool -} - type JWTMiddleware struct { - Options Options + validateToken ValidateToken + errorHandler ErrorHandler + credentialsOptional bool + tokenExtractor TokenExtractor + validateOnOptions bool } -func OnError(w http.ResponseWriter, r *http.Request, err string) { - http.Error(w, err, http.StatusUnauthorized) -} - -// New constructs a new Secure instance with supplied options. -func New(options ...Options) *JWTMiddleware { - - var opts Options - if len(options) == 0 { - opts = Options{} - } else { - opts = options[0] - } +// Option is how options for the middleware are setup. +type Option func(*JWTMiddleware) - if opts.UserProperty == "" { - opts.UserProperty = "user" +// WithValidateToken sets up the function to be used to validate all tokens. +// See the ValidateToken type for more information. +// Default: TODO: after merge into `v2` +func WithValidateToken(vt ValidateToken) Option { + return func(m *JWTMiddleware) { + m.validateToken = vt } +} - if opts.ErrorHandler == nil { - opts.ErrorHandler = OnError +// WithErrorHandler sets the handler which is called when there are errors in +// the middleware. See the ErrorHandler type for more information. +// Default value: DefaultErrorHandler +func WithErrorHandler(h ErrorHandler) Option { + return func(m *JWTMiddleware) { + m.errorHandler = h } +} - if opts.Extractor == nil { - opts.Extractor = FromAuthHeader +// WithCredentialsOptional sets up if credentials are optional or not. If set +// to true then an empty token will be considered valid. +// Default value: false +func WithCredentialsOptional(value bool) Option { + return func(m *JWTMiddleware) { + m.credentialsOptional = value } +} - return &JWTMiddleware{ - Options: opts, +// WithTokenExtractor sets up the function which extracts the JWT to be +// validated from the request. +// Default: AuthHeaderTokenExtractor +func WithTokenExtractor(e TokenExtractor) Option { + return func(m *JWTMiddleware) { + m.tokenExtractor = e } } -func (m *JWTMiddleware) logf(format string, args ...interface{}) { - if m.Options.Debug { - log.Printf(format, args...) +// WithValidateOnOptions sets up if OPTIONS requests should have their JWT +// validated or not. +// Default: true +func WithValidateOnOptions(value bool) Option { + return func(m *JWTMiddleware) { + m.validateOnOptions = value } } -// HandlerWithNext is a special implementation for Negroni, but could be used elsewhere. -func (m *JWTMiddleware) HandlerWithNext(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { - err := m.CheckJWT(w, r) - - // If there was an error, do not call next. - if err == nil && next != nil { - next(w, r) +// New constructs a new JWTMiddleware instance with the supplied options. +func New(opts ...Option) *JWTMiddleware { + m := &JWTMiddleware{ + validateToken: func(string) (interface{}, error) { panic("not implemented") }, + errorHandler: DefaultErrorHandler, + credentialsOptional: false, + tokenExtractor: AuthHeaderTokenExtractor, + validateOnOptions: true, } -} -func (m *JWTMiddleware) Handler(h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Let secure process the request. If it returns an error, - // that indicates the request should not continue. - err := m.CheckJWT(w, r) + for _, opt := range opts { + opt(m) + } - // If there was an error, do not continue. - if err != nil { - return - } + return m +} - h.ServeHTTP(w, r) - }) +// DefaultErrorHandler is the default error handler implementation for the +// middleware. If an error handler is not provided via the WithErrorHandler +// option this will be used. +func DefaultErrorHandler(w http.ResponseWriter, r *http.Request, err error) { + switch { + case errors.Is(err, ErrJWTMissing): + w.WriteHeader(http.StatusBadRequest) + case errors.Is(err, ErrJWTInvalid): + w.WriteHeader(http.StatusUnauthorized) + default: + w.WriteHeader(http.StatusInternalServerError) + } } -// FromAuthHeader is a "TokenExtractor" that takes a give request and extracts -// the JWT token from the Authorization header. -func FromAuthHeader(r *http.Request) (string, error) { +// AuthHeaderTokenExtractor is a TokenExtractor that takes a request and +// extracts the token from the Authorization header. +func AuthHeaderTokenExtractor(r *http.Request) (string, error) { authHeader := r.Header.Get("Authorization") if authHeader == "" { - return "", nil // No error, just no token + return "", nil // No error, just no JWT } - // TODO: Make this a bit more robust, parsing-wise authHeaderParts := strings.Fields(authHeader) if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { return "", errors.New("Authorization header format must be Bearer {token}") @@ -135,17 +167,19 @@ func FromAuthHeader(r *http.Request) (string, error) { return authHeaderParts[1], nil } -// FromParameter returns a function that extracts the token from the specified -// query string parameter -func FromParameter(param string) TokenExtractor { +// ParameterTokenExtractor returns a TokenExtractor that extracts the token +// from the specified query string parameter +func ParameterTokenExtractor(param string) TokenExtractor { return func(r *http.Request) (string, error) { return r.URL.Query().Get(param), nil } } -// FromFirst returns a function that runs multiple token extractors and takes the -// first token it finds -func FromFirst(extractors ...TokenExtractor) TokenExtractor { +// MultiTokenExtractor returns a TokenExtractor that runs multiple +// TokenExtractors and takes the TokenExtractor that does not return an empty +// token. If a TokenExtractor returns an error that error is immediately +// returned. +func MultiTokenExtractor(extractors ...TokenExtractor) TokenExtractor { return func(r *http.Request) (string, error) { for _, ex := range extractors { token, err := ex(r) @@ -160,59 +194,50 @@ func FromFirst(extractors ...TokenExtractor) TokenExtractor { } } -func (m *JWTMiddleware) CheckJWT(w http.ResponseWriter, r *http.Request) error { - if !m.Options.EnableAuthOnOptions { - if r.Method == "OPTIONS" { - return nil +// CheckJWT is the main middleware function which performs the main logic. It +// is passed an http.Handler which will be called if the JWT passes validation. +func (m *JWTMiddleware) CheckJWT(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // if we don't validate on OPTIONS and this is OPTIONS then + // continue onto next without validating + if !m.validateOnOptions && r.Method == http.MethodOptions { + next.ServeHTTP(w, r) + return } - } - - // Use the specified token extractor to extract a token from the request - token, err := m.Options.Extractor(r) - - // If debugging is turned on, log the outcome - if err != nil { - m.logf("Error extracting JWT: %v", err) - } else { - m.logf("Token extracted: %s", token) - } - - // If an error occurs, call the error handler and return an error - if err != nil { - m.Options.ErrorHandler(w, r, err.Error()) - return fmt.Errorf("Error extracting token: %w", err) - } - // If the token is empty... - if token == "" { - // Check if it was required - if m.Options.CredentialsOptional { - m.logf(" No credentials found (CredentialsOptional=true)") - // No error, just no token (and that is ok given that CredentialsOptional is true) - return nil + token, err := m.tokenExtractor(r) + if err != nil { + // this is not ErrJWTMissing because an error here means that + // the tokenExtractor had an error and _not_ that the token was + // missing. + m.errorHandler(w, r, fmt.Errorf("error extracting token: %w", err)) + return } - // If we get here, the required token is missing - errorMsg := "Required authorization token not found" - m.Options.ErrorHandler(w, r, errorMsg) - m.logf(" Error: No credentials found (CredentialsOptional=false)") - return fmt.Errorf(errorMsg) - } - - validToken, err := m.Options.Validate(token) + if token == "" { + // if credentials are optional continue onto next + // without validating + if m.credentialsOptional { + next.ServeHTTP(w, r) + return + } - if err != nil { - m.logf("Token is invalid") - m.Options.ErrorHandler(w, r, "The token isn't valid") - return err - } + // credentials were not optional so we error + m.errorHandler(w, r, ErrJWTMissing) + return + } - m.logf("JWT: %v", validToken) + // validate the token using the token validator + validToken, err := m.validateToken(token) + if err != nil { + m.errorHandler(w, r, &invalidError{details: err}) + return + } - // If we get here, everything worked and we can set the - // user property in context. - newRequest := r.WithContext(context.WithValue(r.Context(), m.Options.UserProperty, validToken)) - // Update the current request with the new context information. - *r = *newRequest - return nil + // no err means we have a valid token, so set it into the + // context and continue onto next + newRequest := r.WithContext(context.WithValue(r.Context(), ContextKey{}, validToken)) + r = newRequest + next.ServeHTTP(w, r) + }) } diff --git a/jwtmiddleware_test.go b/jwtmiddleware_test.go index cab5cd51..1012bd13 100644 --- a/jwtmiddleware_test.go +++ b/jwtmiddleware_test.go @@ -1,216 +1,319 @@ package jwtmiddleware import ( - "encoding/json" + "errors" "fmt" "io/ioutil" "net/http" "net/http/httptest" - "strings" + "net/url" "testing" - "github.com/form3tech-oss/jwt-go" - "github.com/gorilla/mux" - . "github.com/smartystreets/goconvey/convey" - "github.com/urfave/negroni" + "github.com/google/go-cmp/cmp" ) -// defaultAuthorizationHeaderName is the default header name where the Auth -// token should be written -const defaultAuthorizationHeaderName = "Authorization" +// defaults tests against the default setup +// TODO(joncarl): replace with actual JWTs once we have the validate stuff plumbed in +func Test_defaults(t *testing.T) { + tests := []struct { + name string + options []Option + method string + token string -// userPropertyName is the property name that will be set in the request context -const userPropertyName = "custom-user-property" + wantToken map[string]string + wantStatusCode int + wantBody string + }{ + { + name: "happy path", + options: []Option{WithValidateToken(func(token string) (interface{}, error) { + return map[string]string{"foo": "bar"}, nil + })}, + token: "bearer abc", + wantToken: map[string]string{"foo": "bar"}, + wantStatusCode: http.StatusOK, + wantBody: "authenticated", + }, + { + name: "validate on options", + options: []Option{WithValidateToken(func(token string) (interface{}, error) { + return map[string]string{"foo": "bar"}, nil + })}, + method: http.MethodOptions, + token: "bearer abc", + wantToken: map[string]string{"foo": "bar"}, + wantStatusCode: http.StatusOK, + wantBody: "authenticated", + }, + { + name: "bad token format", + options: []Option{WithValidateToken(func(token string) (interface{}, error) { + return map[string]string{"foo": "bar"}, nil + })}, + token: "abc", + wantStatusCode: http.StatusInternalServerError, + }, + { + name: "credentials not optional", + options: []Option{WithValidateToken(func(token string) (interface{}, error) { + return map[string]string{"foo": "bar"}, nil + })}, + token: "", + wantStatusCode: http.StatusBadRequest, + }, + { + name: "validate token errors", + options: []Option{WithValidateToken(func(token string) (interface{}, error) { + return nil, errors.New("validate token error") + })}, + token: "bearer abc", + wantStatusCode: http.StatusUnauthorized, + }, + { + name: "validateOnOptions set to false", + options: []Option{ + WithValidateOnOptions(false), + WithValidateToken(func(token string) (interface{}, error) { + return nil, errors.New("should not hit me since we are not validating on options") + }), + }, + method: http.MethodOptions, + token: "bearer abc", + wantStatusCode: http.StatusOK, + wantBody: "authenticated", + }, + { + name: "tokenExtractor errors", + options: []Option{WithTokenExtractor(func(r *http.Request) (string, error) { + return "", errors.New("token extractor error") + })}, + wantStatusCode: http.StatusInternalServerError, + }, + { + name: "credentialsOptional true", + options: []Option{ + WithCredentialsOptional(true), + WithTokenExtractor(func(r *http.Request) (string, error) { + return "", nil + }), + WithValidateToken(func(token string) (interface{}, error) { + return nil, errors.New("should not hit me since credentials are optional and there are none") + }), + }, + wantStatusCode: http.StatusOK, + wantBody: "authenticated", + }, + { + name: "credentialsOptional false", + options: []Option{ + WithCredentialsOptional(false), + WithTokenExtractor(func(r *http.Request) (string, error) { + return "", nil + }), + WithValidateToken(func(token string) (interface{}, error) { + return nil, errors.New("should not hit me since ErrJWTMissing should be returned") + }), + }, + wantStatusCode: http.StatusBadRequest, + }, + } -// the bytes read from the keys/sample-key file -// private key generated with http://kjur.github.io/jsjws/tool_jwt.html -var privateKey []byte + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var actualContextToken map[string]string -// TestUnauthenticatedRequest will perform requests with no Authorization header -func TestUnauthenticatedRequest(t *testing.T) { - Convey("Simple unauthenticated request", t, func() { - Convey("Unauthenticated GET to / path should return a 200 response", func() { - w := makeUnauthenticatedRequest("GET", "/") - So(w.Code, ShouldEqual, http.StatusOK) - }) - Convey("Unauthenticated GET to /protected path should return a 401 response", func() { - w := makeUnauthenticatedRequest("GET", "/protected") - So(w.Code, ShouldEqual, http.StatusUnauthorized) - }) - }) -} + if tc.method == "" { + tc.method = http.MethodGet + } -// TestAuthenticatedRequest will perform requests with an Authorization header -func TestAuthenticatedRequest(t *testing.T) { - var e error - privateKey, e = readPrivateKey() - if e != nil { - panic(e) - } - Convey("Simple authenticated requests", t, func() { - Convey("Authenticated GET to / path should return a 200 response", func() { - w := makeAuthenticatedRequest("GET", "/", jwt.MapClaims{"foo": "bar"}, nil) - So(w.Code, ShouldEqual, http.StatusOK) - }) - Convey("Authenticated GET to /protected path should return a 200 response if expected algorithm is not specified", func() { - var expectedAlgorithm jwt.SigningMethod = nil - w := makeAuthenticatedRequest("GET", "/protected", jwt.MapClaims{"foo": "bar"}, expectedAlgorithm) - So(w.Code, ShouldEqual, http.StatusOK) - responseBytes, err := ioutil.ReadAll(w.Body) - if err != nil { - panic(err) + m := New(tc.options...) + ts := httptest.NewServer(m.CheckJWT(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if ctxToken, ok := r.Context().Value(ContextKey{}).(map[string]string); ok { + actualContextToken = ctxToken + } + fmt.Fprint(w, "authenticated") + }))) + defer ts.Close() + + client := ts.Client() + req, _ := http.NewRequest(tc.method, ts.URL, nil) + + if len(tc.token) > 0 { + req.Header.Add("Authorization", tc.token) } - responseString := string(responseBytes) - // check that the encoded data in the jwt was properly returned as json - So(responseString, ShouldEqual, `{"text":"bar"}`) - }) - Convey("Authenticated GET to /protected path should return a 200 response if expected algorithm is correct", func() { - expectedAlgorithm := jwt.SigningMethodHS256 - w := makeAuthenticatedRequest("GET", "/protected", jwt.MapClaims{"foo": "bar"}, expectedAlgorithm) - So(w.Code, ShouldEqual, http.StatusOK) - responseBytes, err := ioutil.ReadAll(w.Body) + + res, err := client.Do(req) if err != nil { - panic(err) + t.Fatal(err) } - responseString := string(responseBytes) - // check that the encoded data in the jwt was properly returned as json - So(responseString, ShouldEqual, `{"text":"bar"}`) - }) - Convey("Authenticated GET to /protected path should return a 401 response if algorithm is not expected one", func() { - expectedAlgorithm := jwt.SigningMethodRS256 - w := makeAuthenticatedRequest("GET", "/protected", jwt.MapClaims{"foo": "bar"}, expectedAlgorithm) - So(w.Code, ShouldEqual, http.StatusUnauthorized) - responseBytes, err := ioutil.ReadAll(w.Body) + + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() if err != nil { - panic(err) + t.Fatal(err) + } + + if want, got := tc.wantStatusCode, res.StatusCode; want != got { + t.Fatalf("want status code %d, got %d", want, got) + } + + if want, got := tc.wantBody, string(body); !cmp.Equal(want, got) { + t.Fatal(cmp.Diff(want, got)) + } + + if want, got := tc.wantToken, actualContextToken; !cmp.Equal(want, got) { + t.Fatal(cmp.Diff(want, got)) } - responseString := string(responseBytes) - // check that the encoded data in the jwt was properly returned as json - So(strings.TrimSpace(responseString), ShouldEqual, "Expected RS256 signing method but token specified HS256") }) - }) -} + } -func makeUnauthenticatedRequest(method string, url string) *httptest.ResponseRecorder { - return makeAuthenticatedRequest(method, url, nil, nil) } -func makeAuthenticatedRequest(method string, url string, c jwt.Claims, expectedSignatureAlgorithm jwt.SigningMethod) *httptest.ResponseRecorder { - r, _ := http.NewRequest(method, url, nil) - if c != nil { - token := jwt.New(jwt.SigningMethodHS256) - token.Claims = c - // private key generated with http://kjur.github.io/jsjws/tool_jwt.html - s, e := token.SignedString(privateKey) - if e != nil { - panic(e) +func Test_invalidError(t *testing.T) { + t.Run("Is", func(t *testing.T) { + e := invalidError{details: errors.New("error details")} + + if !errors.Is(&e, ErrJWTInvalid) { + t.Fatal("expected invalidError to be ErrJWTInvalid via errors.Is, but it was not") } - r.Header.Set(defaultAuthorizationHeaderName, fmt.Sprintf("bearer %v", s)) - } - w := httptest.NewRecorder() - n := createNegroniMiddleware(expectedSignatureAlgorithm) - n.ServeHTTP(w, r) - return w -} + }) -func createNegroniMiddleware(expectedSignatureAlgorithm jwt.SigningMethod) *negroni.Negroni { - // create a gorilla mux router for public requests - publicRouter := mux.NewRouter().StrictSlash(true) - publicRouter.Methods("GET"). - Path("/"). - Name("Index"). - Handler(http.HandlerFunc(indexHandler)) - - // create a gorilla mux route for protected requests - // the routes will be tested for jwt tokens in the default auth header - protectedRouter := mux.NewRouter().StrictSlash(true) - protectedRouter.Methods("GET"). - Path("/protected"). - Name("Protected"). - Handler(http.HandlerFunc(protectedHandler)) - // create a negroni handler for public routes - negPublic := negroni.New() - negPublic.UseHandler(publicRouter) - - // negroni handler for api request - negProtected := negroni.New() - //add the JWT negroni handler - negProtected.Use(negroni.HandlerFunc(JWT(expectedSignatureAlgorithm).HandlerWithNext)) - negProtected.UseHandler(protectedRouter) - - //Create the main router - mainRouter := mux.NewRouter().StrictSlash(true) - - mainRouter.Handle("/", negPublic) - mainRouter.Handle("/protected", negProtected) - //if routes match the handle prefix then I need to add this dummy matcher {_dummy:.*} - mainRouter.Handle("/protected/{_dummy:.*}", negProtected) - - n := negroni.Classic() - // This are the "GLOBAL" middlewares that will be applied to every request - // examples are listed below: - //n.Use(gzip.Gzip(gzip.DefaultCompression)) - //n.Use(negroni.HandlerFunc(SecurityMiddleware().HandlerFuncWithNext)) - n.UseHandler(mainRouter) - - return n -} + t.Run("Error", func(t *testing.T) { + e := invalidError{details: errors.New("error details")} -// JWT creates the middleware that parses a JWT encoded token -func JWT(expectedSignatureAlgorithm jwt.SigningMethod) *JWTMiddleware { - return New(Options{ - Debug: false, - CredentialsOptional: false, - UserProperty: userPropertyName, - ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) { - if privateKey == nil { - var err error - privateKey, err = readPrivateKey() - if err != nil { - panic(err) - } - } - return privateKey, nil - }, - SigningMethod: expectedSignatureAlgorithm, + mustErrorMsg(t, "jwt invalid: error details", &e) }) -} -// readPrivateKey will load the keys/sample-key file into the -// global privateKey variable -func readPrivateKey() ([]byte, error) { - privateKey, e := ioutil.ReadFile("keys/sample-key") - return privateKey, e -} + t.Run("Unwrap", func(t *testing.T) { + expectedErr := errors.New("expected err") + e := invalidError{details: expectedErr} -// indexHandler will return an empty 200 OK response -func indexHandler(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) + // under the hood errors.Is is unwrapping the invalidError via + // Unwrap(). + if !errors.Is(&e, expectedErr) { + t.Fatal("expected invalidError to be expectedErr via errors.Is, but it was not") + } + }) } -// protectedHandler will return the content of the "foo" encoded data -// in the token as json -> {"text":"bar"} -func protectedHandler(w http.ResponseWriter, r *http.Request) { - // retrieve the token from the context - u := r.Context().Value(userPropertyName) - user := u.(*jwt.Token) - respondJSON(user.Claims.(jwt.MapClaims)["foo"].(string), w) +func Test_MultiTokenExtractor(t *testing.T) { + t.Run("uses first extractor that replies", func(t *testing.T) { + wantToken := "i am token" + + exNothing := func(r *http.Request) (string, error) { + return "", nil + } + exSomething := func(r *http.Request) (string, error) { + return wantToken, nil + } + exFail := func(r *http.Request) (string, error) { + return "", errors.New("should not have hit me") + } + + ex := MultiTokenExtractor(exNothing, exSomething, exFail) + + gotToken, err := ex(&http.Request{}) + mustErrorMsg(t, "", err) + + if wantToken != gotToken { + t.Fatalf("wanted token: %q, got: %q", wantToken, gotToken) + } + }) + + t.Run("stops when an extractor fails", func(t *testing.T) { + wantErr := "extraction fail" + + exNothing := func(r *http.Request) (string, error) { + return "", nil + } + exFail := func(r *http.Request) (string, error) { + return "", errors.New(wantErr) + } + + ex := MultiTokenExtractor(exNothing, exFail) + + gotToken, err := ex(&http.Request{}) + mustErrorMsg(t, wantErr, err) + + if gotToken != "" { + t.Fatalf("did not want a token but got: %q", gotToken) + } + }) + + t.Run("defaults to empty", func(t *testing.T) { + exNothing := func(r *http.Request) (string, error) { + return "", nil + } + + ex := MultiTokenExtractor(exNothing, exNothing, exNothing) + + gotToken, err := ex(&http.Request{}) + mustErrorMsg(t, "", err) + + if "" != gotToken { + t.Fatalf("wanted empty token but got: %q", gotToken) + } + }) } -// Response quick n' dirty Response struct to be encoded as json -type Response struct { - Text string `json:"text"` +func Test_ParameterTokenExtractor(t *testing.T) { + wantToken := "i am token" + param := "i-am-param" + + u, err := url.Parse(fmt.Sprintf("http://localhost?%s=%s", param, wantToken)) + mustErrorMsg(t, "", err) + r := &http.Request{URL: u} + + ex := ParameterTokenExtractor(param) + + gotToken, err := ex(r) + mustErrorMsg(t, "", err) + + if wantToken != gotToken { + t.Fatalf("wanted token: %q, got: %q", wantToken, gotToken) + } } -// respondJSON will take an string to write through the writer as json -func respondJSON(text string, w http.ResponseWriter) { - response := Response{text} +func Test_AuthHeaderTokenExtractor(t *testing.T) { + tests := []struct { + name string + request *http.Request + wantToken string + wantError string + }{ + { + name: "empty / no header", + request: &http.Request{}, + }, + { + name: "token in header", + request: &http.Request{Header: http.Header{"Authorization": []string{fmt.Sprintf("Bearer %s", "i-am-token")}}}, + wantToken: "i-am-token", + }, + { + name: "no bearer", + request: &http.Request{Header: http.Header{"Authorization": []string{"i-am-token"}}}, + wantError: "Authorization header format must be Bearer {token}", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + gotToken, gotError := AuthHeaderTokenExtractor(tc.request) + mustErrorMsg(t, tc.wantError, gotError) + + if tc.wantToken != gotToken { + t.Fatalf("wanted token: %q, got: %q", tc.wantToken, gotToken) + } + + }) + } +} - jsonResponse, err := json.Marshal(response) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return +func mustErrorMsg(t testing.TB, want string, got error) { + if (want == "" && got != nil) || + (want != "" && (got == nil || got.Error() != want)) { + t.Fatalf("want error: %s, got %v", want, got) } - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write(jsonResponse) }