diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 00000000..5c10616d --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,8 @@ +blank_issues_enabled: false +contact_links: + - name: Auth0 Community + url: https://community.auth0.com/c/sdks/5 + about: Discuss this SDK in the Auth0 Community forums + - name: SDK API Documentation + url: https://pkg.go.dev/github.com/auth0/go-jwt-middleware + about: Read the API documentation for this SDK diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 00000000..4d7ef3f7 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,35 @@ +**Please do not report security vulnerabilities here**. +The [Responsible Disclosure Program](https://auth0.com/whitehat) details the procedure for disclosing security issues. + +**Thank you in advance for helping us to improve this library!** +Your attention to detail here is greatly appreciated and will help us respond as quickly as possible. +For general support or usage questions, use the [Auth0 Community](https://community.auth0.com/) or +[Auth0 Support](https://support.auth0.com/). +Finally, to avoid duplicates, please search existing Issues before submitting one here. + +By submitting an Issue to this repository, you agree to the terms within the +[Auth0 Code of Conduct](https://github.com/auth0/open-source-template/blob/master/CODE-OF-CONDUCT.md). + +### Describe the problem you'd like to have solved + + + +### Describe the ideal solution + + + +## Alternatives and current workarounds + + + +### Additional context + + diff --git a/.github/ISSUE_TEMPLATE/report_a_bug.md b/.github/ISSUE_TEMPLATE/report_a_bug.md new file mode 100644 index 00000000..ecf852e2 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/report_a_bug.md @@ -0,0 +1,55 @@ +**Please do not report security vulnerabilities here**. +The [Responsible Disclosure Program](https://auth0.com/whitehat) details the procedure for disclosing security issues. + +**Thank you in advance for helping us to improve this library!** +Your attention to detail here is greatly appreciated and will help us respond as quickly as possible. +For general support or usage questions, use the [Auth0 Community](https://community.auth0.com/) or +[Auth0 Support](https://support.auth0.com/). +Finally, to avoid duplicates, please search existing Issues before submitting one here. + +By submitting an Issue to this repository, you agree to the terms within the +[Auth0 Code of Conduct](https://github.com/auth0/open-source-template/blob/master/CODE-OF-CONDUCT.md). + +### Describe the problem + + + +### What was the expected behavior? + + + +### Reproduction + + + +### Environment + + diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 00000000..7f4f6268 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,53 @@ +## Description + + + + +## References + + + + +## Testing + + + +- [ ] This change adds test coverage for new/changed/fixed functionality + + +## Checklist + + + +- [x] I have read and agreed to the terms within the [Auth0 Code of Conduct](https://github.com/auth0/open-source-template/blob/master/CODE-OF-CONDUCT.md). +- [x] I have read the [Auth0 General Contribution Guidelines](https://github.com/auth0/open-source-template/blob/master/GENERAL-CONTRIBUTING.md). +- [ ] I have reviewed my own code beforehand. +- [ ] I have added documentation for new/changed functionality in this PR. +- [ ] All active GitHub checks for tests, formatting, and security are passing. +- [ ] The correct base branch is being used, if not `master`. diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index eaef4fdb..2b5ffb26 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -13,15 +13,15 @@ jobs: runs-on: ubuntu-latest steps: - name: install go - uses: actions/setup-go@v1 + uses: actions/setup-go@v2 with: - go-version: 1.14 + go-version: 1.17 - name: checkout code uses: actions/checkout@v2 - name: golangci-lint uses: golangci/golangci-lint-action@v2 with: - args: -v --timeout=5m --exclude SA1029 + args: -v --timeout=5m skip-build-cache: true skip-go-installation: true skip-pkg-cache: true diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index e57dd210..c29c4ce8 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -15,9 +15,9 @@ jobs: runs-on: ${{ matrix.os }} steps: - name: install go - uses: actions/setup-go@v1 + uses: actions/setup-go@v2 with: - go-version: 1.14 + go-version: 1.17 - name: checkout code uses: actions/checkout@v2 - name: test diff --git a/MIGRATION_GUIDE.md b/MIGRATION_GUIDE.md new file mode 100644 index 00000000..7660091b --- /dev/null +++ b/MIGRATION_GUIDE.md @@ -0,0 +1,112 @@ +# Migration Guide + +This guide covers the migration from [v1](https://github.com/auth0/go-jwt-middleware/tree/v1.0.1). + +### `jwtmiddleware.Options` + +Now handled by individual [jwtmiddleware.Option](https://pkg.go.dev/github.com/auth0/go-jwt-middleware#Option) items. +They can be passed to [jwtmiddleware.New](https://pkg.go.dev/github.com/auth0/go-jwt-middleware#New) after the +[jwtmiddleware.ValidateToken](https://pkg.go.dev/github.com/auth0/go-jwt-middleware#ValidateToken) input: + +```golang +jwtmiddleware.New(validator, WithCredentialsOptional(true), ...) +``` + +#### `ValidationKeyGetter` + +Token validation is now handled via a token provider which can be learned about in the section on +[jwtmiddleware.New](https://pkg.go.dev/github.com/auth0/go-jwt-middleware#New). + +#### `UserProperty` + +This is now handled in the validation provider. + +#### `ErrorHandler` + +We now provide a public [jwtmiddleware.ErrorHandler](https://pkg.go.dev/github.com/auth0/go-jwt-middleware#ErrorHandler) +type: + +```golang +type ErrorHandler func(w http.ResponseWriter, r *http.Request, err error) +``` + +A [default](https://pkg.go.dev/github.com/auth0/go-jwt-middleware#DefaultErrorHandler) is provided which translates +errors into appropriate HTTP status codes. + +You might want to wrap the default, so you can hook things into, like logging: + +```golang +myErrHandler := func(w http.ResponseWriter, r *http.Request, err error) { + fmt.Printf("error in token validation: %+v\n", err) + + jwtmiddleware.DefaultErrorHandler(w, r, err) +} + +jwtMiddleware := jwtmiddleware.New(validator.ValidateToken, jwtmiddleware.WithErrorHandler(myErrHandler)) +``` + +#### `CredentialsOptional` + +Use the option function +[jwtmiddleware.WithCredentialsOptional(true|false)](https://pkg.go.dev/github.com/auth0/go-jwt-middleware#WithCredentialsOptional). +Default is false. + +#### `Extractor` + +Use the option function [jwtmiddleware.WithTokenExtractor](https://pkg.go.dev/github.com/auth0/go-jwt-middleware#WithTokenExtractor). +Default is to extract tokens from the auth header. + +We provide 3 different token extractors: +- [jwtmiddleware.AuthHeaderTokenExtractor](https://pkg.go.dev/github.com/auth0/go-jwt-middleware#AuthHeaderTokenExtractor) renamed from `jwtmiddleware.FromAuthHeader`. +- [jwtmiddleware.CookieTokenExtractor](https://pkg.go.dev/github.com/auth0/go-jwt-middleware#CookieTokenExtractor) a new extractor. +- [jwtmiddleware.ParameterTokenExtractor](https://pkg.go.dev/github.com/auth0/go-jwt-middleware#ParameterTokenExtractor) renamed from `jwtmiddleware.FromParameter`. + +And also an extractor which can combine multiple different extractors together: +[jwtmiddleware.MultiTokenExtractor](https://pkg.go.dev/github.com/auth0/go-jwt-middleware#MultiTokenExtractor) renamed from `jwtmiddleware.FromFirst`. + +#### `Debug` + +Removed. Please review individual exception messages for error details. + +#### `EnableAuthOnOptions` + +Use the option function [jwtmiddleware.WithValidateOnOptions(true|false)](https://pkg.go.dev/github.com/auth0/go-jwt-middleware#WithValidateOnOptions). Default is true. + +#### `SigningMethod` + +This is now handled in the validation provider. + +### `jwtmiddleware.New` + +A token provider is set up in the middleware by passing a +[jwtmiddleware.ValidateToken](https://pkg.go.dev/github.com/auth0/go-jwt-middleware#ValidateToken) +function: + +```golang +func(context.Context, string) (interface{}, error) +``` + +to [jwtmiddleware.New](https://pkg.go.dev/github.com/auth0/go-jwt-middleware#New). + +In the example above you can see +[github.com/auth0/go-jwt-middleware/validate/josev2](https://pkg.go.dev/github.com/auth0/go-jwt-middleware@v2.0.0/validate/josev2) +being used. + +This change was made to allow the JWT validation provider to be easily switched out. + +Options are passed into `jwtmiddleware.New` after validation provider and use the `jwtmiddleware.With...` functions to +set options. + +### `jwtmiddleware.Handler*` + +Both `jwtmiddleware.HandlerWithNext` and `jwtmiddleware.Handler` have been dropped. +You can use [jwtmiddleware.CheckJWT](https://pkg.go.dev/github.com/auth0/go-jwt-middleware#JWTMiddleware.CheckJWT) +instead which takes in an `http.Handler` and returns an `http.Handler`. + +### `jwtmiddleware.CheckJWT` + +This function has been reworked to be the main middleware handler piece, and so we've dropped the functionality of it +returning and error. + +If you need to handle any errors please use the +[jwtmiddleware.WithErrorHandler](https://pkg.go.dev/github.com/auth0/go-jwt-middleware#WithErrorHandler) function. diff --git a/Makefile b/Makefile index 2457b8c5..6e2e3585 100644 --- a/Makefile +++ b/Makefile @@ -6,7 +6,7 @@ test: ## Run tests. .PHONY: lint lint: ## Run golangci-lint. - golangci-lint run -v --timeout=5m --exclude SA1029 + golangci-lint run -v --timeout=5m .PHONY: help help: diff --git a/README.md b/README.md index c0fba09a..6fd8b71a 100644 --- a/README.md +++ b/README.md @@ -1,240 +1,158 @@ # GO JWT Middleware -### :mega: testers wanted :mega: -We are looking for testers for a new major version of this package. We've been working hard on the new version and want to get it tested out by users before we officially release it. For details on how to test it out please see [this](https://github.com/auth0/go-jwt-middleware/issues/86#issuecomment-881737547) issue comment. +[![GoDoc](https://pkg.go.dev/badge/github.com/auth0/go-jwt-middleware.svg)](https://pkg.go.dev/github.com/auth0/go-jwt-middleware) +[![License](https://img.shields.io/github/license/auth0/go-jwt-middleware.svg)](https://github.com/auth0/go-jwt-middleware/blob/master/LICENSE) +[![Release](https://img.shields.io/github/v/release/auth0/go-jwt-middleware)](https://github.com/auth0/go-jwt-middleware/releases/latest) +[![Codecov](https://codecov.io/gh/auth0/go-jwt-middleware/branch/master/graph/badge.svg?token=fs2WrOXe9H)](https://codecov.io/gh/auth0/go-jwt-middleware) +[![Tests](https://github.com/auth0/go-jwt-middleware/actions/workflows/test.yaml/badge.svg)](https://github.com/auth0/go-jwt-middleware/actions/workflows/test.yaml?query=branch%3Amaster) +[![Stars](https://img.shields.io/github/stars/auth0/go-jwt-middleware.svg)](https://github.com/auth0/go-jwt-middleware/stargazers) +[![Contributors](https://img.shields.io/github/contributors/auth0/go-jwt-middleware)](https://github.com/auth0/go-jwt-middleware/graphs/contributors) -In this release we’ve addressed some long-standing asks and made some major improvements: -- Replaceable JWT validation - you can now bring your favorite JWT package to validate tokens by ensuring it conforms to a simple interface. We provide two implementations for two different JWT packages. -- We now support a custom error handler. -- Under the hood we clone the `http.Request` instead of a shallow copy in order to better support reverse proxies. -- We now support extracting JWTs from cookies. -- We now store the JWT information using a non-string context key to conform to Golang best practices. -- A caching provider for JWKS is now provided to help you with rate limits from your identity provider. -- We’ve switched errors to use github.com/pkg/errors to provide better error context. If you’re not familiar with the package, don’t worry as it adheres to the error interface. +Golang middleware to check and validate [JWTs](jwt.io) in the request and add the valid token contents to the request +context. ---- +------------------------------------- -**NOTE:** We released this version using a fork of jwt-go in order to address a security vulnerability. Due to jwt-go not being actively maintained we will be looking to switch to a more actively maintained package in the near future. +## Table of Contents -A middleware that will check that a [JWT](http://jwt.io/) is sent on the `Authorization` header and will then set the content of the JWT into the `user` variable of the request. +- [Installation](#installation) +- [Usage](#usage) +- [Migration Guide](#migration-guide) +- [Issue Reporting](#issue-reporting) +- [Author](#author) +- [License](#license) -This module lets you authenticate HTTP requests using JWT tokens in your Go Programming Language applications. JWTs are typically used to protect API endpoints, and are often issued using OpenID Connect. +------------------------------------- -## Key Features +## Installation -* Ability to **check the `Authorization` header for a JWT** -* **Decode the JWT** and set the content of it to the request context - -## Installing - -````bash +```shell go get github.com/auth0/go-jwt-middleware -```` - -## Using it - -You can use `jwtmiddleware` with default `net/http` as follows. - -````go -// main.go -package main - -import ( - "fmt" - "net/http" - - "github.com/auth0/go-jwt-middleware" - "github.com/form3tech-oss/jwt-go" -) - -var myHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - user := r.Context().Value("user") - fmt.Fprintf(w, "This is an authenticated request") - fmt.Fprintf(w, "Claim content:\n") - for k, v := range user.(*jwt.Token).Claims.(jwt.MapClaims) { - fmt.Fprintf(w, "%s :\t%#v\n", k, v) - } -}) +``` -func main() { - jwtMiddleware := jwtmiddleware.New(jwtmiddleware.Options{ - ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) { - return []byte("My Secret"), nil - }, - // When set, the middleware verifies that tokens are signed with the specific signing algorithm - // If the signing method is not constant the ValidationKeyGetter callback can be used to implement additional checks - // Important to avoid security issues described here: https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/ - SigningMethod: jwt.SigningMethodHS256, - }) - - app := jwtMiddleware.Handler(myHandler) - http.ListenAndServe("0.0.0.0:3000", app) -} -```` +[[table of contents]](#table-of-contents) -You can also use it with Negroni as follows: +## Usage -````go -// main.go +```golang package main import ( + "context" "encoding/json" + "log" "net/http" + "gopkg.in/square/go-jose.v2" + "gopkg.in/square/go-jose.v2/jwt" + "github.com/auth0/go-jwt-middleware" - "github.com/form3tech-oss/jwt-go" - "github.com/gorilla/mux" - "github.com/urfave/negroni" + "github.com/auth0/go-jwt-middleware/validate/josev2" ) -func main() { - StartServer() -} - -func StartServer() { - r := mux.NewRouter() - - jwtMiddleware := jwtmiddleware.New(jwtmiddleware.Options{ - ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) { - return []byte("My Secret"), nil - }, - SigningMethod: jwt.SigningMethodHS256, - }) - - r.HandleFunc("/ping", PingHandler) - r.Handle("/secured/ping", negroni.New( - negroni.HandlerFunc(jwtMiddleware.HandlerWithNext), - negroni.Wrap(http.HandlerFunc(SecuredPingHandler)), - )) - http.Handle("/", r) - http.ListenAndServe(":3001", nil) -} - -type Response struct { - Text string `json:"text"` -} - -func respondJSON(text string, w http.ResponseWriter) { - response := Response{text} +var handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + claims := r.Context().Value(jwtmiddleware.ContextKey{}).(*josev2.UserContext) - jsonResponse, err := json.Marshal(response) + payload, err := json.Marshal(claims) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } w.Header().Set("Content-Type", "application/json") - w.Write(jsonResponse) -} + w.Write(payload) +}) -func PingHandler(w http.ResponseWriter, r *http.Request) { - respondJSON("All good. You don't need to be authenticated to call this", w) -} +func main() { + keyFunc := func(ctx context.Context) (interface{}, error) { + // Our token must be signed using this data. + return []byte("secret"), nil + } -func SecuredPingHandler(w http.ResponseWriter, r *http.Request) { - respondJSON("All good. You only get this message if you're authenticated", w) -} -```` - -## Options - -````go -// Options is a struct for specifying configuration options for the middleware. -type Options struct { - // The function that will return the Key to validate the JWT. - // It can be either a shared secret or a public key. - // Default value: nil - ValidationKeyGetter jwt.Keyfunc - // The name of the property in the request where the user information - // from the JWT will be stored. - // Default value: "user" - UserProperty string - // The function that will be called when there's an error validating the token - // Default value: - ErrorHandler errorHandler - // A boolean indicating if the credentials are required or not - // Default value: false - CredentialsOptional bool - // A function that extracts the token from the request - // Default: FromAuthHeader (i.e., from Authorization header as bearer token) - Extractor TokenExtractor - // Debug flag turns on debugging output - // Default: false - Debug bool - // When set, all requests with the OPTIONS method will use authentication - // Default: false - EnableAuthOnOptions bool - // When set, the middelware verifies that tokens are signed with the specific signing algorithm - // If the signing method is not constant the ValidationKeyGetter callback can be used to implement additional checks - // Important to avoid security issues described here: https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/ - // Default: nil - SigningMethod jwt.SigningMethod -} -```` + expectedClaimsFunc := func() jwt.Expected { + // By setting up expected claims we are saying + // a token must have the data we specify. + return jwt.Expected{ + Issuer: "josev2-example", + } + } -### Token Extraction + // Set up the josev2 validator. + validator, err := josev2.New( + keyFunc, + jose.HS256, + josev2.WithExpectedClaims(expectedClaimsFunc), + ) + if err != nil { + log.Fatalf("failed to set up the josev2 validator: %v", err) + } -The default value for the `Extractor` option is the `FromAuthHeader` -function which assumes that the JWT will be provided as a bearer token -in an `Authorization` header, i.e., + // Set up the middleware. + middleware := jwtmiddleware.New(validator.ValidateToken) -``` -Authorization: bearer {token} + http.ListenAndServe("0.0.0.0:3000", middleware.CheckJWT(handler)) +} ``` -To extract the token from a query string parameter, you can use the -`FromParameter` function, e.g., +After running that code (`go run main.go`) you can then curl the http server from another terminal: -```go -jwtmiddleware.New(jwtmiddleware.Options{ - Extractor: jwtmiddleware.FromParameter("auth_code"), -}) +``` +$ curl -H "Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyLCJpc3MiOiJqb3NldjItZXhhbXBsZSJ9.e0lGglk9-m-n-t07eA5f7qgXGM-nD4ekwJkYVKprIUM" localhost:3000 ``` -In this case, the `FromParameter` function will look for a JWT in the -`auth_code` query parameter. - -Or, if you want to allow both, you can use the `FromFirst` function to -try and extract the token first in one way and then in one or more -other ways, e.g., +That should give you the following response: -```go -jwtmiddleware.New(jwtmiddleware.Options{ - Extractor: jwtmiddleware.FromFirst(jwtmiddleware.FromAuthHeader, - jwtmiddleware.FromParameter("auth_code")), -}) +``` +{ + "CustomClaims": null, + "RegisteredClaims": { + "iss": "josev2-example", + "sub": "1234567890", + "iat": 1516239022 + } +} ``` -## Examples +The JWT included in the Authorization header above is signed with `secret`. -You can check out working examples in the [examples folder](https://github.com/auth0/go-jwt-middleware/tree/master/examples) +To test how the response would look like with an invalid token: +``` +$ curl -v -H "Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.yiDw9IDNCa1WXCoDfPR_g356vSsHBEerqh9IvnD49QE" localhost:3000 +``` -## What is Auth0? +That should give you the following response: -Auth0 helps you to: +``` +... +< HTTP/1.1 401 Unauthorized +< Content-Type: application/json +{"message":"JWT is invalid."} +... +``` + +[[table of contents]](#table-of-contents) -* Add authentication with [multiple authentication sources](https://docs.auth0.com/identityproviders), either social like **Google, Facebook, Microsoft Account, LinkedIn, GitHub, Twitter, Box, Salesforce, amont others**, or enterprise identity systems like **Windows Azure AD, Google Apps, Active Directory, ADFS or any SAML Identity Provider**. -* Add authentication through more traditional **[username/password databases](https://docs.auth0.com/mysql-connection-tutorial)**. -* Add support for **[linking different user accounts](https://docs.auth0.com/link-accounts)** with the same user. -* Support for generating signed [Json Web Tokens](https://docs.auth0.com/jwt) to call your APIs and **flow the user identity** securely. -* Analytics of how, when and where users are logging in. -* Pull data from other sources and add it to the user profile, through [JavaScript rules](https://docs.auth0.com/rules). +## Migration Guide -## Create a free Auth0 Account +If you are moving from v1 to v2 please check our [migration guide](MIGRATION_GUIDE.md). -1. Go to [Auth0](https://auth0.com) and click Sign Up. -2. Use Google, GitHub or Microsoft Account to login. +[[table of contents]](#table-of-contents) ## Issue Reporting If you have found a bug or if you have a feature request, please report them at this repository issues section. Please do not report security vulnerabilities on the public GitHub issue tracker. The [Responsible Disclosure Program](https://auth0.com/whitehat) details the procedure for disclosing security issues. +[[table of contents]](#table-of-contents) + ## Author -[Auth0](auth0.com) +[Auth0](https://auth0.com/) + +[[table of contents]](#table-of-contents) ## License This project is licensed under the MIT license. See the [LICENSE](LICENSE) file for more info. + +[[table of contents]](#table-of-contents) diff --git a/error_handler.go b/error_handler.go new file mode 100644 index 00000000..d1795bfb --- /dev/null +++ b/error_handler.go @@ -0,0 +1,70 @@ +package jwtmiddleware + +import ( + "fmt" + "net/http" + + "github.com/pkg/errors" +) + +var ( + // ErrJWTMissing is returned when the JWT is missing. + ErrJWTMissing = errors.New("jwt missing") + + // ErrJWTInvalid is returned when the JWT is invalid. + ErrJWTInvalid = errors.New("jwt invalid") +) + +// ErrorHandler is a handler which is called when an error occurs in the +// JWTMiddleware. Among some general errors, this handler also determines the +// response of the JWTMiddleware when a token is not found or is invalid. The +// err can be checked to be ErrJWTMissing or ErrJWTInvalid for specific cases. +// The default handler will return a status code of 400 for ErrJWTMissing, +// 401 for ErrJWTInvalid, and 500 for all other errors. If you implement your +// own ErrorHandler you MUST take into consideration the error types as not +// properly responding to them or having a poorly implemented handler could +// result in the JWTMiddleware not functioning as intended. +type ErrorHandler func(w http.ResponseWriter, r *http.Request, err error) + +// DefaultErrorHandler is the default error handler implementation for the +// JWTMiddleware. If an error handler is not provided via the WithErrorHandler +// option this will be used. +func DefaultErrorHandler(w http.ResponseWriter, r *http.Request, err error) { + w.Header().Set("Content-Type", "application/json") + + switch { + case errors.Is(err, ErrJWTMissing): + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"message":"JWT is missing."}`)) + case errors.Is(err, ErrJWTInvalid): + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"message":"JWT is invalid."}`)) + default: + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"message":"Something went wrong while checking the JWT."}`)) + } +} + +// invalidError handles wrapping a JWT validation error with +// the concrete error ErrJWTInvalid. We do not expose this +// publicly because the interface methods of Is and Unwrap +// should give the user all they need. +type invalidError struct { + details error +} + +// Is allows the error to support equality to ErrJWTInvalid. +func (e *invalidError) Is(target error) bool { + return target == ErrJWTInvalid +} + +// Error returns a string representation of the error. +func (e *invalidError) Error() string { + return fmt.Sprintf("%s: %s", ErrJWTInvalid, e.details) +} + +// Unwrap allows the error to support equality to the +// underlying error and not just ErrJWTInvalid. +func (e *invalidError) Unwrap() error { + return e.details +} diff --git a/examples/http-example/README.md b/examples/http-example/README.md index a3909d35..a01385d3 100644 --- a/examples/http-example/README.md +++ b/examples/http-example/README.md @@ -4,6 +4,6 @@ This is an example of how to use the http middleware. # Using it -To try this out, first install all dependencies with `go install` and then run `go run main.go` to start the app. +To try this out, first install all dependencies with `go mod download` and then run `go run main.go` to start the app. -* Call `http://localhost:3000` with a JWT signed with `My Secret` to get a response back. +* Call `http://localhost:3000` with a JWT signed with `My Secret` (you can use [jwt.io](https://jwt.io/) for this) to get a response back. diff --git a/examples/http-example/main.go b/examples/http-example/main.go index 79724ae8..664cd621 100644 --- a/examples/http-example/main.go +++ b/examples/http-example/main.go @@ -1,33 +1,57 @@ package main import ( - "fmt" + "context" + "encoding/json" + "log" "net/http" - jwtmiddleware "github.com/auth0/go-jwt-middleware" - "github.com/form3tech-oss/jwt-go" + "gopkg.in/square/go-jose.v2" + "gopkg.in/square/go-jose.v2/jwt" + + "github.com/auth0/go-jwt-middleware" + "github.com/auth0/go-jwt-middleware/validate/josev2" ) -var myHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - user := r.Context().Value("user") - fmt.Fprintf(w, "This is an authenticated request") - fmt.Fprintf(w, "Claim content:\n") - for k, v := range user.(*jwt.Token).Claims.(jwt.MapClaims) { - fmt.Fprintf(w, "%s :\t%#v\n", k, v) +var handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + claims := r.Context().Value(jwtmiddleware.ContextKey{}).(*josev2.UserContext) + + payload, err := json.Marshal(claims) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return } + + w.Header().Set("Content-Type", "application/json") + w.Write(payload) }) func main() { - jwtMiddleware := jwtmiddleware.New(jwtmiddleware.Options{ - ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) { - return []byte("My Secret"), nil - }, - // When set, the middleware verifies that tokens are signed with the specific signing algorithm - // If the signing method is not constant the ValidationKeyGetter callback can be used to implement additional checks - // Important to avoid security issues described here: https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/ - SigningMethod: jwt.SigningMethodHS256, - }) - - app := jwtMiddleware.Handler(myHandler) - http.ListenAndServe("0.0.0.0:3000", app) + keyFunc := func(ctx context.Context) (interface{}, error) { + // Our token must be signed using this data. + return []byte("secret"), nil + } + + expectedClaimsFunc := func() jwt.Expected { + // By setting up expected claims we are saying + // a token must have the data we specify. + return jwt.Expected{ + Issuer: "josev2-example", + } + } + + // Set up the josev2 validator. + validator, err := josev2.New( + keyFunc, + jose.HS256, + josev2.WithExpectedClaims(expectedClaimsFunc), + ) + if err != nil { + log.Fatalf("failed to set up the josev2 validator: %v", err) + } + + // Set up the middleware. + middleware := jwtmiddleware.New(validator.ValidateToken) + + http.ListenAndServe("0.0.0.0:3000", middleware.CheckJWT(handler)) } diff --git a/examples/http-jwks-example/README.md b/examples/http-jwks-example/README.md new file mode 100644 index 00000000..33214b54 --- /dev/null +++ b/examples/http-jwks-example/README.md @@ -0,0 +1,15 @@ +# HTTP JWKS example + +This is an example of how to use the http middleware with JWKS. + +# Using it + +To try this out: +1. Install all dependencies with `go mod download` +1. Go to https://manage.auth0.com/ and create a new API. +1. Go to the "Test" tab of the API and copy the cURL example. +1. Run the cURL example in your terminal and copy the `access_token` from the response. The tool jq can be helpful for this. +1. In the example change `` on line 30 to the domain used in the cURL request. +1. Run the example with `go run main.go`. +1. In a new terminal use cURL to talk to the API: `curl -v --request GET --url http://localhost:3000` +1. Now try it again with the `access_token` you copied earlier and run `curl -v --request GET --url http://localhost:3000 --header "authorization: Bearer $TOKEN"` to see a successful request. diff --git a/examples/http-jwks-example/main.go b/examples/http-jwks-example/main.go new file mode 100644 index 00000000..77ca7075 --- /dev/null +++ b/examples/http-jwks-example/main.go @@ -0,0 +1,50 @@ +package main + +import ( + "encoding/json" + "log" + "net/http" + "net/url" + "time" + + "gopkg.in/square/go-jose.v2" + + "github.com/auth0/go-jwt-middleware" + "github.com/auth0/go-jwt-middleware/validate/josev2" +) + +var handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + claims := r.Context().Value(jwtmiddleware.ContextKey{}).(*josev2.UserContext) + + payload, err := json.Marshal(claims) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + w.Write(payload) +}) + +func main() { + issuerURL, err := url.Parse("https://") + if err != nil { + log.Fatalf("failed to parse the issuer url: %v", err) + } + + provider := josev2.NewCachingJWKSProvider(*issuerURL, 5*time.Minute) + + // Set up the josev2 validator. + validator, err := josev2.New( + provider.KeyFunc, + jose.RS256, + ) + if err != nil { + log.Fatalf("failed to set up the josev2 validator: %v", err) + } + + // Set up the middleware. + middleware := jwtmiddleware.New(validator.ValidateToken) + + http.ListenAndServe("0.0.0.0:3000", middleware.CheckJWT(handler)) +} diff --git a/examples/martini-example/README.md b/examples/martini-example/README.md deleted file mode 100644 index 62394f4f..00000000 --- a/examples/martini-example/README.md +++ /dev/null @@ -1,10 +0,0 @@ -# Martini example - -This is an example of how to use the middleware with Martini. - -# Using it - -To try this out, first install all dependencies with `go install` and then run `go run main.go` to start the app. - -* Call `http://localhost:3001/ping` to get a JSon response without the need of a JWT. -* Call `http://localhost:3001/secured/ping` with a JWT signed with `My Secret` to get a response back. diff --git a/examples/martini-example/main.go b/examples/martini-example/main.go deleted file mode 100644 index 7a9cc9fb..00000000 --- a/examples/martini-example/main.go +++ /dev/null @@ -1,59 +0,0 @@ -package main - -import ( - "encoding/json" - "net/http" - - jwtmiddleware "github.com/auth0/go-jwt-middleware" - "github.com/form3tech-oss/jwt-go" - "github.com/go-martini/martini" -) - -func main() { - - StartServer() - -} - -func StartServer() { - m := martini.Classic() - - jwtMiddleware := jwtmiddleware.New(jwtmiddleware.Options{ - ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) { - return []byte("My Secret"), nil - }, - SigningMethod: jwt.SigningMethodHS256, - }) - - m.Get("/ping", PingHandler) - m.Get("/secured/ping", func(w http.ResponseWriter, r *http.Request) { - jwtMiddleware.CheckJWT(w, r) - }, SecuredPingHandler) - - m.Run() -} - -type Response struct { - Text string `json:"text"` -} - -func respondJSON(text string, w http.ResponseWriter) { - response := Response{text} - - jsonResponse, err := json.Marshal(response) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - w.Header().Set("Content-Type", "application/json") - w.Write(jsonResponse) -} - -func PingHandler(w http.ResponseWriter, r *http.Request) { - respondJSON("All good. You don't need to be authenticated to call this", w) -} - -func SecuredPingHandler(w http.ResponseWriter, r *http.Request) { - respondJSON("All good. You only get this message if you're authenticated", w) -} diff --git a/examples/negroni-example/README.md b/examples/negroni-example/README.md deleted file mode 100644 index 0e1a035d..00000000 --- a/examples/negroni-example/README.md +++ /dev/null @@ -1,10 +0,0 @@ -# Negroni example - -This is an example of how to use the Negroni middleware. - -# Using it - -To try this out, first install all dependencies with `go install` and then run `go run main.go` to start the app. - -* Call `http://localhost:3001/ping` to get a JSon response without the need of a JWT. -* Call `http://localhost:3001/secured/ping` with a JWT signed with `My Secret` to get a response back. \ No newline at end of file diff --git a/examples/negroni-example/main.go b/examples/negroni-example/main.go deleted file mode 100644 index a2640a4d..00000000 --- a/examples/negroni-example/main.go +++ /dev/null @@ -1,61 +0,0 @@ -package main - -import ( - "encoding/json" - "net/http" - - jwtmiddleware "github.com/auth0/go-jwt-middleware" - "github.com/form3tech-oss/jwt-go" - "github.com/gorilla/mux" - "github.com/urfave/negroni" -) - -func main() { - - StartServer() - -} - -func StartServer() { - r := mux.NewRouter() - - jwtMiddleware := jwtmiddleware.New(jwtmiddleware.Options{ - ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) { - return []byte("My Secret"), nil - }, - SigningMethod: jwt.SigningMethodHS256, - }) - - r.HandleFunc("/ping", PingHandler) - r.Handle("/secured/ping", negroni.New( - negroni.HandlerFunc(jwtMiddleware.HandlerWithNext), - negroni.Wrap(http.HandlerFunc(SecuredPingHandler)), - )) - http.Handle("/", r) - http.ListenAndServe(":3001", nil) -} - -type Response struct { - Text string `json:"text"` -} - -func respondJSON(text string, w http.ResponseWriter) { - response := Response{text} - - jsonResponse, err := json.Marshal(response) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - w.Header().Set("Content-Type", "application/json") - w.Write(jsonResponse) -} - -func PingHandler(w http.ResponseWriter, r *http.Request) { - respondJSON("All good. You don't need to be authenticated to call this", w) -} - -func SecuredPingHandler(w http.ResponseWriter, r *http.Request) { - respondJSON("All good. You only get this message if you're authenticated", w) -} diff --git a/extractor.go b/extractor.go new file mode 100644 index 00000000..9cbb3268 --- /dev/null +++ b/extractor.go @@ -0,0 +1,75 @@ +package jwtmiddleware + +import ( + "net/http" + "strings" + + "github.com/pkg/errors" +) + +// TokenExtractor is a function that takes a request as input and returns +// either a token or an error. An error should only be returned if an attempt +// to specify a token was found, but the information was somehow incorrectly +// formed. In the case where a token is simply not present, this should not +// be treated as an error. An empty string should be returned in that case. +type TokenExtractor func(r *http.Request) (string, error) + +// AuthHeaderTokenExtractor is a TokenExtractor that takes a request +// and extracts the token from the Authorization header. +func AuthHeaderTokenExtractor(r *http.Request) (string, error) { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + return "", nil // No error, just no JWT. + } + + authHeaderParts := strings.Fields(authHeader) + if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { + return "", errors.New("Authorization header format must be Bearer {token}") + } + + return authHeaderParts[1], nil +} + +// CookieTokenExtractor builds a TokenExtractor that takes a request and +// extracts the token from the cookie using the passed in cookieName. +func CookieTokenExtractor(cookieName string) TokenExtractor { + return func(r *http.Request) (string, error) { + cookie, err := r.Cookie(cookieName) + if err != nil { + return "", err + } + + if cookie != nil { + return cookie.Value, nil + } + + return "", nil // No error, just no JWT. + } +} + +// ParameterTokenExtractor returns a TokenExtractor that extracts +// the token from the specified query string parameter. +func ParameterTokenExtractor(param string) TokenExtractor { + return func(r *http.Request) (string, error) { + return r.URL.Query().Get(param), nil + } +} + +// MultiTokenExtractor returns a TokenExtractor that runs multiple TokenExtractors +// and takes the one that does not return an empty token. If a TokenExtractor +// returns an error that error is immediately returned. +func MultiTokenExtractor(extractors ...TokenExtractor) TokenExtractor { + return func(r *http.Request) (string, error) { + for _, ex := range extractors { + token, err := ex(r) + if err != nil { + return "", err + } + + if token != "" { + return token, nil + } + } + return "", nil + } +} diff --git a/go.mod b/go.mod index e103a6e8..f24fe63e 100644 --- a/go.mod +++ b/go.mod @@ -1,12 +1,11 @@ module github.com/auth0/go-jwt-middleware -go 1.14 +go 1.17 require ( - github.com/form3tech-oss/jwt-go v3.2.2+incompatible - github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00 // indirect - github.com/gorilla/mux v1.7.4 - github.com/smartystreets/assertions v1.1.0 // indirect - github.com/smartystreets/goconvey v1.6.4 - github.com/urfave/negroni v1.0.0 + github.com/golang-jwt/jwt/v4 v4.1.0 + github.com/google/go-cmp v0.5.6 + github.com/pkg/errors v0.9.1 + gopkg.in/square/go-jose.v2 v2.5.1 + golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 // indirect ) diff --git a/go.sum b/go.sum index 0e1e3bef..fdb87599 100644 --- a/go.sum +++ b/go.sum @@ -1,21 +1,28 @@ -github.com/form3tech-oss/jwt-go v3.2.2+incompatible h1:TcekIExNqud5crz4xD2pavyTgWiPvpYe4Xau31I0PRk= -github.com/form3tech-oss/jwt-go v3.2.2+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k= -github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= -github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00 h1:l5lAOZEym3oK3SQ2HBHWsJUfbNBiTXJDeW2QDxw9AQ0= -github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= -github.com/gorilla/mux v1.7.4 h1:VuZ8uybHlWmqV03+zRzdwKL4tUnIp1MAQtp1mIFE1bc= -github.com/gorilla/mux v1.7.4/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= -github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= -github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= -github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= -github.com/smartystreets/assertions v1.1.0 h1:MkTeG1DMwsrdH7QtLXy5W+fUxWq+vmb6cLmyJ7aRtF0= -github.com/smartystreets/assertions v1.1.0/go.mod h1:tcbTF8ujkAEcZ8TElKY+i30BzYlVhC/LOxJk7iOWnoo= -github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s= -github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= -github.com/urfave/negroni v1.0.0 h1:kIimOitoypq34K7TG7DUaJ9kq/N4Ofuwi1sjz0KipXc= -github.com/urfave/negroni v1.0.0/go.mod h1:Meg73S6kFm/4PpbYdq35yYWoCZ9mS/YSx+lKnmiohz4= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang-jwt/jwt/v4 v4.1.0 h1:XUgk2Ex5veyVFVeLm0xhusUTQybEbexJXrvPNOKkSY0= +github.com/golang-jwt/jwt/v4 v4.1.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= +github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= +github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 h1:7I4JAnoQBe7ZtJcBaYHi5UtiO8tQHbUSXxL+pnGRANg= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/square/go-jose.v2 v2.5.1 h1:7odma5RETjNHWJnR32wx8t+Io4djHE1PqxCFx3iiZ2w= +gopkg.in/square/go-jose.v2 v2.5.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/oidc/oidc.go b/internal/oidc/oidc.go new file mode 100644 index 00000000..aee4081c --- /dev/null +++ b/internal/oidc/oidc.go @@ -0,0 +1,39 @@ +package oidc + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "path" +) + +// WellKnownEndpoints holds the well known OIDC endpoints. +type WellKnownEndpoints struct { + JWKSURI string `json:"jwks_uri"` +} + +// GetWellKnownEndpointsFromIssuerURL gets the well known endpoints for the passed in issuer url. +func GetWellKnownEndpointsFromIssuerURL(ctx context.Context, issuerURL url.URL) (*WellKnownEndpoints, error) { + issuerURL.Path = path.Join(issuerURL.Path, ".well-known/openid-configuration") + + request, err := http.NewRequest(http.MethodGet, issuerURL.String(), nil) + if err != nil { + return nil, fmt.Errorf("could not build request to get well known endpoints: %w", err) + } + request = request.WithContext(ctx) + + response, err := http.DefaultClient.Do(request) + if err != nil { + return nil, fmt.Errorf("could not get well known endpoints from url %s: %w", issuerURL.String(), err) + } + defer response.Body.Close() + + var wkEndpoints WellKnownEndpoints + if err = json.NewDecoder(response.Body).Decode(&wkEndpoints); err != nil { + return nil, fmt.Errorf("could not decode json body when getting well known endpoints: %w", err) + } + + return &wkEndpoints, nil +} diff --git a/jwtmiddleware.go b/jwtmiddleware.go deleted file mode 100644 index 6e457ba9..00000000 --- a/jwtmiddleware.go +++ /dev/null @@ -1,237 +0,0 @@ -package jwtmiddleware - -import ( - "context" - "errors" - "fmt" - "log" - "net/http" - "strings" - - "github.com/form3tech-oss/jwt-go" -) - -// A function called whenever an error is encountered -type errorHandler func(w http.ResponseWriter, r *http.Request, err string) - -// TokenExtractor is a function that takes a request as input and returns -// either a token or an error. An error should only be returned if an attempt -// to specify a token was found, but the information was somehow incorrectly -// formed. In the case where a token is simply not present, this should not -// be treated as an error. An empty string should be returned in that case. -type TokenExtractor func(r *http.Request) (string, error) - -// Options is a struct for specifying configuration options for the middleware. -type Options struct { - // The function that will return the Key to validate the JWT. - // It can be either a shared secret or a public key. - // Default value: nil - ValidationKeyGetter jwt.Keyfunc - // The name of the property in the request where the user information - // from the JWT will be stored. - // Default value: "user" - UserProperty string - // The function that will be called when there's an error validating the token - // Default value: - ErrorHandler errorHandler - // A boolean indicating if the credentials are required or not - // Default value: false - CredentialsOptional bool - // A function that extracts the token from the request - // Default: FromAuthHeader (i.e., from Authorization header as bearer token) - Extractor TokenExtractor - // Debug flag turns on debugging output - // Default: false - Debug bool - // When set, all requests with the OPTIONS method will use authentication - // Default: false - EnableAuthOnOptions bool - // When set, the middelware verifies that tokens are signed with the specific signing algorithm - // If the signing method is not constant the ValidationKeyGetter callback can be used to implement additional checks - // Important to avoid security issues described here: https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/ - // Default: nil - SigningMethod jwt.SigningMethod -} - -type JWTMiddleware struct { - Options Options -} - -func OnError(w http.ResponseWriter, r *http.Request, err string) { - http.Error(w, err, http.StatusUnauthorized) -} - -// New constructs a new Secure instance with supplied options. -func New(options ...Options) *JWTMiddleware { - - var opts Options - if len(options) == 0 { - opts = Options{} - } else { - opts = options[0] - } - - if opts.UserProperty == "" { - opts.UserProperty = "user" - } - - if opts.ErrorHandler == nil { - opts.ErrorHandler = OnError - } - - if opts.Extractor == nil { - opts.Extractor = FromAuthHeader - } - - return &JWTMiddleware{ - Options: opts, - } -} - -func (m *JWTMiddleware) logf(format string, args ...interface{}) { - if m.Options.Debug { - log.Printf(format, args...) - } -} - -// HandlerWithNext is a special implementation for Negroni, but could be used elsewhere. -func (m *JWTMiddleware) HandlerWithNext(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { - err := m.CheckJWT(w, r) - - // If there was an error, do not call next. - if err == nil && next != nil { - next(w, r) - } -} - -func (m *JWTMiddleware) Handler(h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Let secure process the request. If it returns an error, - // that indicates the request should not continue. - err := m.CheckJWT(w, r) - - // If there was an error, do not continue. - if err != nil { - return - } - - h.ServeHTTP(w, r) - }) -} - -// FromAuthHeader is a "TokenExtractor" that takes a give request and extracts -// the JWT token from the Authorization header. -func FromAuthHeader(r *http.Request) (string, error) { - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - return "", nil // No error, just no token - } - - // TODO: Make this a bit more robust, parsing-wise - authHeaderParts := strings.Fields(authHeader) - if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { - return "", errors.New("Authorization header format must be Bearer {token}") - } - - return authHeaderParts[1], nil -} - -// FromParameter returns a function that extracts the token from the specified -// query string parameter -func FromParameter(param string) TokenExtractor { - return func(r *http.Request) (string, error) { - return r.URL.Query().Get(param), nil - } -} - -// FromFirst returns a function that runs multiple token extractors and takes the -// first token it finds -func FromFirst(extractors ...TokenExtractor) TokenExtractor { - return func(r *http.Request) (string, error) { - for _, ex := range extractors { - token, err := ex(r) - if err != nil { - return "", err - } - if token != "" { - return token, nil - } - } - return "", nil - } -} - -func (m *JWTMiddleware) CheckJWT(w http.ResponseWriter, r *http.Request) error { - if !m.Options.EnableAuthOnOptions { - if r.Method == "OPTIONS" { - return nil - } - } - - // Use the specified token extractor to extract a token from the request - token, err := m.Options.Extractor(r) - - // If debugging is turned on, log the outcome - if err != nil { - m.logf("Error extracting JWT: %v", err) - } else { - m.logf("Token extracted: %s", token) - } - - // If an error occurs, call the error handler and return an error - if err != nil { - m.Options.ErrorHandler(w, r, err.Error()) - return fmt.Errorf("Error extracting token: %w", err) - } - - // If the token is empty... - if token == "" { - // Check if it was required - if m.Options.CredentialsOptional { - m.logf(" No credentials found (CredentialsOptional=true)") - // No error, just no token (and that is ok given that CredentialsOptional is true) - return nil - } - - // If we get here, the required token is missing - errorMsg := "Required authorization token not found" - m.Options.ErrorHandler(w, r, errorMsg) - m.logf(" Error: No credentials found (CredentialsOptional=false)") - return fmt.Errorf(errorMsg) - } - - // Now parse the token - parsedToken, err := jwt.Parse(token, m.Options.ValidationKeyGetter) - - // Check if there was an error in parsing... - if err != nil { - m.logf("Error parsing token: %v", err) - m.Options.ErrorHandler(w, r, err.Error()) - return fmt.Errorf("Error parsing token: %w", err) - } - - if m.Options.SigningMethod != nil && m.Options.SigningMethod.Alg() != parsedToken.Header["alg"] { - message := fmt.Sprintf("Expected %s signing method but token specified %s", - m.Options.SigningMethod.Alg(), - parsedToken.Header["alg"]) - m.logf("Error validating token algorithm: %s", message) - m.Options.ErrorHandler(w, r, errors.New(message).Error()) - return fmt.Errorf("Error validating token algorithm: %s", message) - } - - // Check if the parsed token is valid... - if !parsedToken.Valid { - m.logf("Token is invalid") - m.Options.ErrorHandler(w, r, "The token isn't valid") - return errors.New("Token is invalid") - } - - m.logf("JWT: %v", parsedToken) - - // If we get here, everything worked and we can set the - // user property in context. - newRequest := r.WithContext(context.WithValue(r.Context(), m.Options.UserProperty, parsedToken)) - // Update the current request with the new context information. - *r = *newRequest - return nil -} diff --git a/jwtmiddleware_test.go b/jwtmiddleware_test.go deleted file mode 100644 index cab5cd51..00000000 --- a/jwtmiddleware_test.go +++ /dev/null @@ -1,216 +0,0 @@ -package jwtmiddleware - -import ( - "encoding/json" - "fmt" - "io/ioutil" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/form3tech-oss/jwt-go" - "github.com/gorilla/mux" - . "github.com/smartystreets/goconvey/convey" - "github.com/urfave/negroni" -) - -// defaultAuthorizationHeaderName is the default header name where the Auth -// token should be written -const defaultAuthorizationHeaderName = "Authorization" - -// userPropertyName is the property name that will be set in the request context -const userPropertyName = "custom-user-property" - -// the bytes read from the keys/sample-key file -// private key generated with http://kjur.github.io/jsjws/tool_jwt.html -var privateKey []byte - -// TestUnauthenticatedRequest will perform requests with no Authorization header -func TestUnauthenticatedRequest(t *testing.T) { - Convey("Simple unauthenticated request", t, func() { - Convey("Unauthenticated GET to / path should return a 200 response", func() { - w := makeUnauthenticatedRequest("GET", "/") - So(w.Code, ShouldEqual, http.StatusOK) - }) - Convey("Unauthenticated GET to /protected path should return a 401 response", func() { - w := makeUnauthenticatedRequest("GET", "/protected") - So(w.Code, ShouldEqual, http.StatusUnauthorized) - }) - }) -} - -// TestAuthenticatedRequest will perform requests with an Authorization header -func TestAuthenticatedRequest(t *testing.T) { - var e error - privateKey, e = readPrivateKey() - if e != nil { - panic(e) - } - Convey("Simple authenticated requests", t, func() { - Convey("Authenticated GET to / path should return a 200 response", func() { - w := makeAuthenticatedRequest("GET", "/", jwt.MapClaims{"foo": "bar"}, nil) - So(w.Code, ShouldEqual, http.StatusOK) - }) - Convey("Authenticated GET to /protected path should return a 200 response if expected algorithm is not specified", func() { - var expectedAlgorithm jwt.SigningMethod = nil - w := makeAuthenticatedRequest("GET", "/protected", jwt.MapClaims{"foo": "bar"}, expectedAlgorithm) - So(w.Code, ShouldEqual, http.StatusOK) - responseBytes, err := ioutil.ReadAll(w.Body) - if err != nil { - panic(err) - } - responseString := string(responseBytes) - // check that the encoded data in the jwt was properly returned as json - So(responseString, ShouldEqual, `{"text":"bar"}`) - }) - Convey("Authenticated GET to /protected path should return a 200 response if expected algorithm is correct", func() { - expectedAlgorithm := jwt.SigningMethodHS256 - w := makeAuthenticatedRequest("GET", "/protected", jwt.MapClaims{"foo": "bar"}, expectedAlgorithm) - So(w.Code, ShouldEqual, http.StatusOK) - responseBytes, err := ioutil.ReadAll(w.Body) - if err != nil { - panic(err) - } - responseString := string(responseBytes) - // check that the encoded data in the jwt was properly returned as json - So(responseString, ShouldEqual, `{"text":"bar"}`) - }) - Convey("Authenticated GET to /protected path should return a 401 response if algorithm is not expected one", func() { - expectedAlgorithm := jwt.SigningMethodRS256 - w := makeAuthenticatedRequest("GET", "/protected", jwt.MapClaims{"foo": "bar"}, expectedAlgorithm) - So(w.Code, ShouldEqual, http.StatusUnauthorized) - responseBytes, err := ioutil.ReadAll(w.Body) - if err != nil { - panic(err) - } - responseString := string(responseBytes) - // check that the encoded data in the jwt was properly returned as json - So(strings.TrimSpace(responseString), ShouldEqual, "Expected RS256 signing method but token specified HS256") - }) - }) -} - -func makeUnauthenticatedRequest(method string, url string) *httptest.ResponseRecorder { - return makeAuthenticatedRequest(method, url, nil, nil) -} - -func makeAuthenticatedRequest(method string, url string, c jwt.Claims, expectedSignatureAlgorithm jwt.SigningMethod) *httptest.ResponseRecorder { - r, _ := http.NewRequest(method, url, nil) - if c != nil { - token := jwt.New(jwt.SigningMethodHS256) - token.Claims = c - // private key generated with http://kjur.github.io/jsjws/tool_jwt.html - s, e := token.SignedString(privateKey) - if e != nil { - panic(e) - } - r.Header.Set(defaultAuthorizationHeaderName, fmt.Sprintf("bearer %v", s)) - } - w := httptest.NewRecorder() - n := createNegroniMiddleware(expectedSignatureAlgorithm) - n.ServeHTTP(w, r) - return w -} - -func createNegroniMiddleware(expectedSignatureAlgorithm jwt.SigningMethod) *negroni.Negroni { - // create a gorilla mux router for public requests - publicRouter := mux.NewRouter().StrictSlash(true) - publicRouter.Methods("GET"). - Path("/"). - Name("Index"). - Handler(http.HandlerFunc(indexHandler)) - - // create a gorilla mux route for protected requests - // the routes will be tested for jwt tokens in the default auth header - protectedRouter := mux.NewRouter().StrictSlash(true) - protectedRouter.Methods("GET"). - Path("/protected"). - Name("Protected"). - Handler(http.HandlerFunc(protectedHandler)) - // create a negroni handler for public routes - negPublic := negroni.New() - negPublic.UseHandler(publicRouter) - - // negroni handler for api request - negProtected := negroni.New() - //add the JWT negroni handler - negProtected.Use(negroni.HandlerFunc(JWT(expectedSignatureAlgorithm).HandlerWithNext)) - negProtected.UseHandler(protectedRouter) - - //Create the main router - mainRouter := mux.NewRouter().StrictSlash(true) - - mainRouter.Handle("/", negPublic) - mainRouter.Handle("/protected", negProtected) - //if routes match the handle prefix then I need to add this dummy matcher {_dummy:.*} - mainRouter.Handle("/protected/{_dummy:.*}", negProtected) - - n := negroni.Classic() - // This are the "GLOBAL" middlewares that will be applied to every request - // examples are listed below: - //n.Use(gzip.Gzip(gzip.DefaultCompression)) - //n.Use(negroni.HandlerFunc(SecurityMiddleware().HandlerFuncWithNext)) - n.UseHandler(mainRouter) - - return n -} - -// JWT creates the middleware that parses a JWT encoded token -func JWT(expectedSignatureAlgorithm jwt.SigningMethod) *JWTMiddleware { - return New(Options{ - Debug: false, - CredentialsOptional: false, - UserProperty: userPropertyName, - ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) { - if privateKey == nil { - var err error - privateKey, err = readPrivateKey() - if err != nil { - panic(err) - } - } - return privateKey, nil - }, - SigningMethod: expectedSignatureAlgorithm, - }) -} - -// readPrivateKey will load the keys/sample-key file into the -// global privateKey variable -func readPrivateKey() ([]byte, error) { - privateKey, e := ioutil.ReadFile("keys/sample-key") - return privateKey, e -} - -// indexHandler will return an empty 200 OK response -func indexHandler(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) -} - -// protectedHandler will return the content of the "foo" encoded data -// in the token as json -> {"text":"bar"} -func protectedHandler(w http.ResponseWriter, r *http.Request) { - // retrieve the token from the context - u := r.Context().Value(userPropertyName) - user := u.(*jwt.Token) - respondJSON(user.Claims.(jwt.MapClaims)["foo"].(string), w) -} - -// Response quick n' dirty Response struct to be encoded as json -type Response struct { - Text string `json:"text"` -} - -// respondJSON will take an string to write through the writer as json -func respondJSON(text string, w http.ResponseWriter) { - response := Response{text} - - jsonResponse, err := json.Marshal(response) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write(jsonResponse) -} diff --git a/keys/sample-key b/keys/sample-key deleted file mode 100644 index 47f557ef..00000000 --- a/keys/sample-key +++ /dev/null @@ -1 +0,0 @@ -eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL2p3dC1pZHAuZXhhbXBsZS5jb20iLCJzdWIiOiJtYWlsdG86bWlrZUBleGFtcGxlLmNvbSIsIm5iZiI6MTQzMDc3OTMwNSwiZXhwIjoxNDYyMzE1MzA1LCJpYXQiOjE0MzA3NzkzMDUsImp0aSI6ImlkMTIzNDU2IiwidHlwIjoiaHR0cHM6Ly9leGFtcGxlLmNvbS9yZWdpc3RlciJ9.KbVlagrOLiy-R65eUrVuno_IAjW-J5i_ySoSrs2SgjU diff --git a/middleware.go b/middleware.go new file mode 100644 index 00000000..89dcd3f7 --- /dev/null +++ b/middleware.go @@ -0,0 +1,92 @@ +package jwtmiddleware + +import ( + "context" + "fmt" + "net/http" +) + +// ContextKey is the key used in the request +// context where the information from a +// validated JWT will be stored. +type ContextKey struct{} + +type JWTMiddleware struct { + validateToken ValidateToken + errorHandler ErrorHandler + tokenExtractor TokenExtractor + credentialsOptional bool + validateOnOptions bool +} + +// ValidateToken takes in a string JWT and makes sure it is valid and +// returns the valid token. If it is not valid it will return nil and +// an error message describing why validation failed. +// Inside ValidateToken things like key and alg checking can happen. +// In the default implementation we can add safe defaults for those. +type ValidateToken func(context.Context, string) (interface{}, error) + +// New constructs a new JWTMiddleware instance with the supplied options. +// It requires a ValidateToken function to be passed in, so it can +// properly validate tokens. +func New(validateToken ValidateToken, opts ...Option) *JWTMiddleware { + m := &JWTMiddleware{ + validateToken: validateToken, + errorHandler: DefaultErrorHandler, + credentialsOptional: false, + tokenExtractor: AuthHeaderTokenExtractor, + validateOnOptions: true, + } + + for _, opt := range opts { + opt(m) + } + + return m +} + +// CheckJWT is the main JWTMiddleware function which performs the main logic. It +// is passed a http.Handler which will be called if the JWT passes validation. +func (m *JWTMiddleware) CheckJWT(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // If we don't validate on OPTIONS and this is OPTIONS + // then continue onto next without validating. + if !m.validateOnOptions && r.Method == http.MethodOptions { + next.ServeHTTP(w, r) + return + } + + token, err := m.tokenExtractor(r) + if err != nil { + // This is not ErrJWTMissing because an error here means that the + // tokenExtractor had an error and _not_ that the token was missing. + m.errorHandler(w, r, fmt.Errorf("error extracting token: %w", err)) + return + } + + if token == "" { + // If credentials are optional continue + // onto next without validating. + if m.credentialsOptional { + next.ServeHTTP(w, r) + return + } + + // Credentials were not optional so we error. + m.errorHandler(w, r, ErrJWTMissing) + return + } + + // Validate the token using the token validator. + validToken, err := m.validateToken(r.Context(), token) + if err != nil { + m.errorHandler(w, r, &invalidError{details: err}) + return + } + + // No err means we have a valid token, so set + // it into the context and continue onto next. + r = r.Clone(context.WithValue(r.Context(), ContextKey{}, validToken)) + next.ServeHTTP(w, r) + }) +} diff --git a/middleware_test.go b/middleware_test.go new file mode 100644 index 00000000..2f39f460 --- /dev/null +++ b/middleware_test.go @@ -0,0 +1,380 @@ +package jwtmiddleware + +import ( + "context" + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/pkg/errors" + "gopkg.in/square/go-jose.v2" + "gopkg.in/square/go-jose.v2/jwt" + + "github.com/auth0/go-jwt-middleware/validate/josev2" +) + +func Test_CheckJWT(t *testing.T) { + var ( + validToken = "bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ0ZXN0aW5nIn0.SdU_8KjnZsQChrVtQpYGxS48DxB4rTM9biq6D4haR70" + invalidToken = "bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ0ZXN0aW5nIn0.eM1Jd7VA7nFSI09FlmLmtuv7cLnv8qicZ8s76-jTOoE" + validContextToken = &josev2.UserContext{ + RegisteredClaims: jwt.Claims{ + Issuer: "testing", + }, + } + ) + + validator, err := josev2.New( + func(_ context.Context) (interface{}, error) { + return []byte("secret"), nil + }, + jose.HS256, + josev2.WithExpectedClaims( + func() jwt.Expected { + return jwt.Expected{Issuer: "testing"} + }, + ), + ) + if err != nil { + t.Fatal(err) + } + + testCases := []struct { + name string + validateToken ValidateToken + options []Option + method string + token string + wantToken interface{} + wantStatusCode int + wantBody string + }{ + { + name: "happy path", + validateToken: validator.ValidateToken, + token: validToken, + wantToken: validContextToken, + wantStatusCode: http.StatusOK, + wantBody: `{"message":"Authenticated."}`, + }, + { + name: "validate on options", + validateToken: validator.ValidateToken, + method: http.MethodOptions, + token: validToken, + wantToken: validContextToken, + wantStatusCode: http.StatusOK, + wantBody: `{"message":"Authenticated."}`, + }, + { + name: "bad token format", + token: "bad", + wantStatusCode: http.StatusInternalServerError, + wantBody: `{"message":"Something went wrong while checking the JWT."}`, + }, + { + name: "credentials not optional", + token: "", + wantStatusCode: http.StatusBadRequest, + wantBody: `{"message":"JWT is missing."}`, + }, + { + name: "validate token errors", + validateToken: validator.ValidateToken, + token: invalidToken, + wantStatusCode: http.StatusUnauthorized, + wantBody: `{"message":"JWT is invalid."}`, + }, + { + name: "validateOnOptions set to false", + options: []Option{ + WithValidateOnOptions(false), + }, + method: http.MethodOptions, + token: validToken, + wantStatusCode: http.StatusOK, + wantBody: `{"message":"Authenticated."}`, + }, + { + name: "tokenExtractor errors", + options: []Option{ + WithTokenExtractor(func(r *http.Request) (string, error) { + return "", errors.New("token extractor error") + }), + }, + wantStatusCode: http.StatusInternalServerError, + wantBody: `{"message":"Something went wrong while checking the JWT."}`, + }, + { + name: "credentialsOptional true", + options: []Option{ + WithCredentialsOptional(true), + WithTokenExtractor(func(r *http.Request) (string, error) { + return "", nil + }), + }, + wantStatusCode: http.StatusOK, + wantBody: `{"message":"Authenticated."}`, + }, + { + name: "credentialsOptional false", + options: []Option{ + WithCredentialsOptional(false), + WithTokenExtractor(func(r *http.Request) (string, error) { + return "", nil + }), + }, + wantStatusCode: http.StatusBadRequest, + wantBody: `{"message":"JWT is missing."}`, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + if testCase.method == "" { + testCase.method = http.MethodGet + } + + middleware := New(testCase.validateToken, testCase.options...) + + var actualContextToken interface{} + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + actualContextToken = r.Context().Value(ContextKey{}) + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"message":"Authenticated."}`)) + }) + + testServer := httptest.NewServer(middleware.CheckJWT(testHandler)) + defer testServer.Close() + + request, err := http.NewRequest(testCase.method, testServer.URL, nil) + if err != nil { + t.Fatal(err) + } + + if testCase.token != "" { + request.Header.Add("Authorization", testCase.token) + } + + response, err := testServer.Client().Do(request) + if err != nil { + t.Fatal(err) + } + + body, err := ioutil.ReadAll(response.Body) + if err != nil { + t.Fatal(err) + } + defer response.Body.Close() + + if want, got := testCase.wantStatusCode, response.StatusCode; want != got { + t.Fatalf("want status code %d, got %d", want, got) + } + + if want, got := "application/json", response.Header.Get("Content-Type"); want != got { + t.Fatalf("want Content-Type %s, got %s", want, got) + } + + if want, got := testCase.wantBody, string(body); !cmp.Equal(want, got) { + t.Fatal(cmp.Diff(want, got)) + } + + if want, got := testCase.wantToken, actualContextToken; !cmp.Equal(want, got) { + t.Fatal(cmp.Diff(want, got)) + } + }) + } +} + +func Test_invalidError(t *testing.T) { + t.Run("Is", func(t *testing.T) { + e := invalidError{details: errors.New("error details")} + + if !errors.Is(&e, ErrJWTInvalid) { + t.Fatal("expected invalidError to be ErrJWTInvalid via errors.Is, but it was not") + } + }) + + t.Run("Error", func(t *testing.T) { + e := invalidError{details: errors.New("error details")} + + mustErrorMsg(t, "jwt invalid: error details", &e) + }) + + t.Run("Unwrap", func(t *testing.T) { + expectedErr := errors.New("expected err") + e := invalidError{details: expectedErr} + + // under the hood errors.Is is unwrapping the invalidError via + // Unwrap(). + if !errors.Is(&e, expectedErr) { + t.Fatal("expected invalidError to be expectedErr via errors.Is, but it was not") + } + }) +} + +func Test_MultiTokenExtractor(t *testing.T) { + t.Run("uses first extractor that replies", func(t *testing.T) { + wantToken := "i am token" + + exNothing := func(r *http.Request) (string, error) { + return "", nil + } + exSomething := func(r *http.Request) (string, error) { + return wantToken, nil + } + exFail := func(r *http.Request) (string, error) { + return "", errors.New("should not have hit me") + } + + ex := MultiTokenExtractor(exNothing, exSomething, exFail) + + gotToken, err := ex(&http.Request{}) + mustErrorMsg(t, "", err) + + if wantToken != gotToken { + t.Fatalf("wanted token: %q, got: %q", wantToken, gotToken) + } + }) + + t.Run("stops when an extractor fails", func(t *testing.T) { + wantErr := "extraction fail" + + exNothing := func(r *http.Request) (string, error) { + return "", nil + } + exFail := func(r *http.Request) (string, error) { + return "", errors.New(wantErr) + } + + ex := MultiTokenExtractor(exNothing, exFail) + + gotToken, err := ex(&http.Request{}) + mustErrorMsg(t, wantErr, err) + + if gotToken != "" { + t.Fatalf("did not want a token but got: %q", gotToken) + } + }) + + t.Run("defaults to empty", func(t *testing.T) { + exNothing := func(r *http.Request) (string, error) { + return "", nil + } + + ex := MultiTokenExtractor(exNothing, exNothing, exNothing) + + gotToken, err := ex(&http.Request{}) + mustErrorMsg(t, "", err) + + if "" != gotToken { + t.Fatalf("wanted empty token but got: %q", gotToken) + } + }) +} + +func Test_ParameterTokenExtractor(t *testing.T) { + wantToken := "i am token" + param := "i-am-param" + + u, err := url.Parse(fmt.Sprintf("http://localhost?%s=%s", param, wantToken)) + mustErrorMsg(t, "", err) + r := &http.Request{URL: u} + + ex := ParameterTokenExtractor(param) + + gotToken, err := ex(r) + mustErrorMsg(t, "", err) + + if wantToken != gotToken { + t.Fatalf("wanted token: %q, got: %q", wantToken, gotToken) + } +} + +func Test_AuthHeaderTokenExtractor(t *testing.T) { + testCases := []struct { + name string + request *http.Request + wantToken string + wantError string + }{ + { + name: "empty / no header", + request: &http.Request{}, + }, + { + name: "token in header", + request: &http.Request{Header: http.Header{"Authorization": []string{fmt.Sprintf("Bearer %s", "i-am-token")}}}, + wantToken: "i-am-token", + }, + { + name: "no bearer", + request: &http.Request{Header: http.Header{"Authorization": []string{"i-am-token"}}}, + wantError: "Authorization header format must be Bearer {token}", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + gotToken, gotError := AuthHeaderTokenExtractor(testCase.request) + mustErrorMsg(t, testCase.wantError, gotError) + + if testCase.wantToken != gotToken { + t.Fatalf("wanted token: %q, got: %q", testCase.wantToken, gotToken) + } + }) + } +} + +func Test_CookieTokenExtractor(t *testing.T) { + testCases := []struct { + name string + cookie *http.Cookie + wantToken string + wantError string + }{ + { + name: "no cookie", + wantError: "http: named cookie not present", + }, + { + name: "token in cookie", + cookie: &http.Cookie{Name: "token", Value: "i-am-token"}, + wantToken: "i-am-token", + }, + { + name: "empty cookie", + cookie: &http.Cookie{Name: "token"}, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + req, _ := http.NewRequest("GET", "http://example.com", nil) + + if testCase.cookie != nil { + req.AddCookie(testCase.cookie) + } + + gotToken, gotError := CookieTokenExtractor("token")(req) + mustErrorMsg(t, testCase.wantError, gotError) + + if testCase.wantToken != gotToken { + t.Fatalf("wanted token: %q, got: %q", testCase.wantToken, gotToken) + } + }) + } +} + +func mustErrorMsg(t testing.TB, want string, got error) { + if (want == "" && got != nil) || + (want != "" && (got == nil || got.Error() != want)) { + t.Fatalf("want error: %s, got %v", want, got) + } +} diff --git a/option.go b/option.go new file mode 100644 index 00000000..3c0b6c6d --- /dev/null +++ b/option.go @@ -0,0 +1,46 @@ +package jwtmiddleware + +// Option is how options for the JWTMiddleware are set up. +type Option func(*JWTMiddleware) + +// WithCredentialsOptional sets up if credentials are +// optional or not. If set to true then an empty token +// will be considered valid. +// +// Default value: false. +func WithCredentialsOptional(value bool) Option { + return func(m *JWTMiddleware) { + m.credentialsOptional = value + } +} + +// WithValidateOnOptions sets up if OPTIONS requests +// should have their JWT validated or not. +// +// Default value: true. +func WithValidateOnOptions(value bool) Option { + return func(m *JWTMiddleware) { + m.validateOnOptions = value + } +} + +// WithErrorHandler sets the handler which is called +// when we encounter errors in the JWTMiddleware. +// See the ErrorHandler type for more information. +// +// Default value: DefaultErrorHandler. +func WithErrorHandler(h ErrorHandler) Option { + return func(m *JWTMiddleware) { + m.errorHandler = h + } +} + +// WithTokenExtractor sets up the function which extracts +// the JWT to be validated from the request. +// +// Default value: AuthHeaderTokenExtractor. +func WithTokenExtractor(e TokenExtractor) Option { + return func(m *JWTMiddleware) { + m.tokenExtractor = e + } +} diff --git a/validate/josev2/doc.go b/validate/josev2/doc.go new file mode 100644 index 00000000..45ac05ac --- /dev/null +++ b/validate/josev2/doc.go @@ -0,0 +1,18 @@ +/* +Package josev2 contains an implementation of jwtmiddleware.ValidateToken using +the Square go-jose package version 2. + +The implementation handles some nuances around JWTs and supports: +- a key func to pull the key(s) used to verify the token signature +- verifying the signature algorithm is what it should be +- validation of "regular" claims +- validation of custom claims +- clock skew allowances + +When this package is used, tokens are returned as `JSONWebToken` from the +gopkg.in/square/go-jose.v2/jwt package. + +Note that while the jose package does support multi-recipient JWTs, this +package does not support them. +*/ +package josev2 diff --git a/validate/josev2/examples/README.md b/validate/josev2/examples/README.md new file mode 100644 index 00000000..d8ba97b3 --- /dev/null +++ b/validate/josev2/examples/README.md @@ -0,0 +1,87 @@ +# josev2 examples + +These examples should get you up and running and understanding how to best use +the validator. + +You will need `jwt-cli` to work through the examples: +``` +npm install --global "@clarketm/jwt-cli" +``` + +In in terminal, run the example to get started: +``` +go run main.go +``` +Now you can follow the examples below. + +### with clockskew +The example allows clock skew of 30 seconds. Let's use a token that expired 45 +seconds ago to show that it will reject this. +``` +export TOKEN=$(jwt sign -n "{\"iat\":$(date -r $(( $(date +%s) - 3645 )) +%s),\"iss\":\"josev2-example\"}" "secret") +curl "127.0.0.1:3000" -H "Authorization: Bearer $TOKEN" +``` + +Now lets generate a token that expired 15 seconds ago and watch as it is not +rejected. +``` +export TOKEN=$(jwt sign -n "{\"iat\":$(date -r $(( $(date +%s) - 3615 )) +%s),\"iss\":\"josev2-example\"}" "secret") +curl "127.0.0.1:3000" -H "Authorization: Bearer $TOKEN" +``` + +### custom claims +We can use custom claims in our token and have the validator pass them back to +us in the user context. When the endpoint responds after a valid request it +prints out the CustomClaims. Let's add two claims to our token to see that it +handles the claim we have defined in CustomClaimsExample but does nothing with +the claim we do not have defined. +``` +export TOKEN=$(jwt sign -n "{\"username\":\"user123\",\"hairColor\":\"brown\",\"iss\":\"josev2-example\"}" "secret") +curl "127.0.0.1:3000" -H "Authorization: Bearer $TOKEN" +``` +It will print out something like +```json +{ + "CustomClaims": { + "username": "user123" + }, + "Claims": { + "iss": "josev2-example", + "exp": 1616801896, + "iat": 1616798296 + } +} +``` +As you can see the `username` claim is there, but the `hairColor` claim is not. + +### custom validaton +Along with custom claims we can also run custom validation logic to determine +if the token should be rejected or not. Our example is setup to reject anything +that has the field `shouldReject` set to `true`. +``` +export TOKEN=$(jwt sign -n "{\"shouldReject\":true,\"iss\":\"josev2-example\"}" "secret") +curl "127.0.0.1:3000" -H "Authorization: Bearer $TOKEN" +``` +It will print out something like +``` +The token isn't valid: custom claims not validated: should reject was set to true +``` +The message comes directly from the custom validation! + +### expected claims +In all of the above examples we've seen the `iss` field being set. That's +because it expects the issuer to be `josev2-example`. This validation is built +right into jose. If we remove the field it will error on that field. +``` +export TOKEN=$(jwt sign -n "{}" "secret") +curl "127.0.0.1:3000" -H "Authorization: Bearer $TOKEN" +``` +It will print out something like +``` +The token isn't valid: expected claims not validated: square/go-jose/jwt: validation failed, invalid issuer claim (iss) +``` + +### JWKS +For a JWKS example please see [examples/http-jwks-example/README.md](../../../examples/http-jwks-example/README.md). + +Take a look through the example code and things will make a lot more sense. diff --git a/validate/josev2/examples/main.go b/validate/josev2/examples/main.go new file mode 100644 index 00000000..11a2b014 --- /dev/null +++ b/validate/josev2/examples/main.go @@ -0,0 +1,110 @@ +package main + +import ( + "context" + "encoding/json" + "log" + "net/http" + "time" + + "github.com/pkg/errors" + "gopkg.in/square/go-jose.v2" + "gopkg.in/square/go-jose.v2/jwt" + + "github.com/auth0/go-jwt-middleware" + "github.com/auth0/go-jwt-middleware/validate/josev2" +) + +var handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + claims := r.Context().Value(jwtmiddleware.ContextKey{}).(*josev2.UserContext) + + payload, err := json.Marshal(claims) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + w.Write(payload) +}) + +// CustomClaimsExample contains custom data we want from the token. +type CustomClaimsExample struct { + Name string `json:"name"` + Username string `json:"username"` + ShouldReject bool `json:"shouldReject,omitempty"` +} + +// Validate does nothing for this example. +func (c *CustomClaimsExample) Validate(ctx context.Context) error { + if c.ShouldReject { + return errors.New("should reject was set to true") + } + return nil +} + +func main() { + keyFunc := func(ctx context.Context) (interface{}, error) { + // Our token must be signed using this data. + return []byte("secret"), nil + } + + expectedClaims := func() jwt.Expected { + // By setting up expected claims we are saying + // a token must have the data we specify. + return jwt.Expected{ + Issuer: "josev2-example", + Time: time.Now(), + } + } + + customClaims := func() josev2.CustomClaims { + // We want this struct to be filled in with + // our custom claims from the token. + return &CustomClaimsExample{} + } + + // Set up the josev2 validator. + validator, err := josev2.New( + keyFunc, + jose.HS256, + josev2.WithExpectedClaims(expectedClaims), + josev2.WithCustomClaims(customClaims), + josev2.WithAllowedClockSkew(30*time.Second), + ) + if err != nil { + log.Fatalf("failed to set up the josev2 validator: %v", err) + } + + // Set up the middleware. + middleware := jwtmiddleware.New(validator.ValidateToken) + + http.ListenAndServe("0.0.0.0:3000", middleware.CheckJWT(handler)) + + // Try it out with: + // + // eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJqb3NldjItZXhhbXBsZSIsInN1YiI6IjEyMzQ1Njc4OTAiLCJuYW1lIjoiSm9obiBEb2UiLCJpYXQiOjE1MTYyMzkwMjIsInVzZXJuYW1lIjoidXNlcjEyMyJ9.1v7S4aF7lVM92bRZ8tVTrKGZ6FwkX-7ybZQA5A7mq8E + // + // which is signed with 'secret' and has the data: + // { + // "iss": "josev2-example", + // "sub": "1234567890", + // "name": "John Doe", + // "iat": 1516239022, + // "username": "user123" + // } + // + // You can also try out the custom validation with: + // + // eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJqb3NldjItZXhhbXBsZSIsInN1YiI6IjEyMzQ1Njc4OTAiLCJuYW1lIjoiSm9obiBEb2UiLCJpYXQiOjE1MTYyMzkwMjIsInVzZXJuYW1lIjoidXNlcjEyMyIsInNob3VsZFJlamVjdCI6dHJ1ZX0.vy-dBpmjnULan2TIHSnGCv-e7Az_mF9yNUe07qf3t8w + // + // which is signed with 'secret' and has the data: + // { + // "iss": "josev2-example", + // "sub": "1234567890", + // "name": "John Doe", + // "iat": 1516239022, + // "username": "user123", + // "shouldReject": true + // } +} diff --git a/validate/josev2/josev2.go b/validate/josev2/josev2.go new file mode 100644 index 00000000..66f044f7 --- /dev/null +++ b/validate/josev2/josev2.go @@ -0,0 +1,148 @@ +package josev2 + +import ( + "context" + "fmt" + "time" + + "github.com/pkg/errors" + "gopkg.in/square/go-jose.v2" + "gopkg.in/square/go-jose.v2/jwt" +) + +// Validator to use with the jose v2 package. +type Validator struct { + keyFunc func(context.Context) (interface{}, error) // Required. + signatureAlgorithm jose.SignatureAlgorithm // Required. + expectedClaims func() jwt.Expected // Optional. + customClaims func() CustomClaims // Optional. + allowedClockSkew time.Duration // Optional. +} + +// Option is how options for the Validator are set up. +type Option func(*Validator) + +// CustomClaims defines any custom data / claims wanted. +// The Validator will call the Validate function which +// is where custom validation logic can be defined. +type CustomClaims interface { + Validate(context.Context) error +} + +// UserContext is the struct that will be inserted into +// the context for the user. CustomClaims will be nil +// unless WithCustomClaims is passed to New. +type UserContext struct { + CustomClaims CustomClaims + RegisteredClaims jwt.Claims +} + +// New sets up a new Validator with the required keyFunc +// and signatureAlgorithm as well as custom options. +func New( + keyFunc func(context.Context) (interface{}, error), + signatureAlgorithm jose.SignatureAlgorithm, + opts ...Option, +) (*Validator, error) { + if keyFunc == nil { + return nil, errors.New("keyFunc is required but was nil") + } + + v := &Validator{ + allowedClockSkew: 0, + keyFunc: keyFunc, + signatureAlgorithm: signatureAlgorithm, + customClaims: nil, + expectedClaims: func() jwt.Expected { + return jwt.Expected{ + Time: time.Now(), + } + }, + } + + for _, opt := range opts { + opt(v) + } + + return v, nil +} + +// WithAllowedClockSkew is an option which sets up the allowed +// clock skew for the token. Note that in order to use this +// the expected claims Time field MUST not be time.IsZero(). +// If this option is not used clock skew is not allowed. +func WithAllowedClockSkew(skew time.Duration) Option { + return func(v *Validator) { + v.allowedClockSkew = skew + } +} + +// WithCustomClaims sets up a function that returns the object +// CustomClaims that will be unmarshalled into and on which +// Validate is called on for custom validation. If this option +// is not used the Validator will do nothing for custom claims. +func WithCustomClaims(f func() CustomClaims) Option { + return func(v *Validator) { + v.customClaims = f + } +} + +// WithExpectedClaims sets up a function that returns the object +// used to validate claims. If this option is not used a default +// jwt.Expected object is used which only validates token time. +func WithExpectedClaims(f func() jwt.Expected) Option { + return func(v *Validator) { + v.expectedClaims = f + } +} + +// ValidateToken validates the passed in JWT using the jose v2 package. +func (v *Validator) ValidateToken(ctx context.Context, tokenString string) (interface{}, error) { + token, err := jwt.ParseSigned(tokenString) + if err != nil { + return nil, fmt.Errorf("could not parse the token: %w", err) + } + + signatureAlgorithm := string(v.signatureAlgorithm) + + // If jwt.ParseSigned did not error there will always be at least one header in the token. + if signatureAlgorithm != "" && signatureAlgorithm != token.Headers[0].Algorithm { + return nil, fmt.Errorf( + "expected %q signing algorithm but token specified %q", + signatureAlgorithm, + token.Headers[0].Algorithm, + ) + } + + key, err := v.keyFunc(ctx) + if err != nil { + return nil, fmt.Errorf("error getting the keys from the key func: %w", err) + } + + claimDest := []interface{}{&jwt.Claims{}} + if v.customClaims != nil { + claimDest = append(claimDest, v.customClaims()) + } + + if err = token.Claims(key, claimDest...); err != nil { + return nil, fmt.Errorf("could not get token claims: %w", err) + } + + userCtx := &UserContext{ + CustomClaims: nil, + RegisteredClaims: *claimDest[0].(*jwt.Claims), + } + + if err = userCtx.RegisteredClaims.ValidateWithLeeway(v.expectedClaims(), v.allowedClockSkew); err != nil { + return nil, fmt.Errorf("expected claims not validated: %w", err) + } + + if v.customClaims != nil { + userCtx.CustomClaims = claimDest[1].(CustomClaims) + if err = userCtx.CustomClaims.Validate(ctx); err != nil { + return nil, fmt.Errorf("custom claims not validated: %w", err) + } + } + + return userCtx, nil +} diff --git a/validate/josev2/josev2_test.go b/validate/josev2/josev2_test.go new file mode 100644 index 00000000..a3f1a46b --- /dev/null +++ b/validate/josev2/josev2_test.go @@ -0,0 +1,153 @@ +package josev2 + +import ( + "context" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/pkg/errors" + "gopkg.in/square/go-jose.v2" + "gopkg.in/square/go-jose.v2/jwt" +) + +type testingCustomClaims struct { + Subject string + ReturnError error +} + +func (tcc *testingCustomClaims) Validate(ctx context.Context) error { + return tcc.ReturnError +} + +func equalErrors(actual error, expected string) bool { + if actual == nil { + return expected == "" + } + return actual.Error() == expected +} + +func Test_Validate(t *testing.T) { + testCases := []struct { + name string + signatureAlgorithm jose.SignatureAlgorithm + token string + keyFuncReturnError error + customClaims CustomClaims + expectedClaims jwt.Expected + expectedError string + expectedContext *UserContext + }{ + { + name: "happy path", + token: `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.Rq8IxqeX7eA6GgYxlcHdPFVRNFFZc5rEI3MQTZZbK3I`, + expectedContext: &UserContext{ + RegisteredClaims: jwt.Claims{Subject: "1234567890"}, + }, + }, + { + // we want to test that when it expects RSA but we send + // HMAC encrypted with the server public key it will + // error + name: "errors on wrong algorithm", + signatureAlgorithm: jose.PS256, + token: `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.XbPfbIHMI6arZ3Y922BhjWgQzWXcXNrz0ogtVhfEd2o`, + expectedError: "expected \"PS256\" signing algorithm but token specified \"HS256\"", + }, + { + name: "errors when jwt.ParseSigned errors", + expectedError: "could not parse the token: square/go-jose: compact JWS format must have three parts", + }, + { + name: "errors when the key func errors", + token: `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.XbPfbIHMI6arZ3Y922BhjWgQzWXcXNrz0ogtVhfEd2o`, + keyFuncReturnError: errors.New("key func error message"), + expectedError: "error getting the keys from the key func: key func error message", + }, + { + name: "errors when tok.Claims errors", + token: `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.hDyICUnkCrwFJnkJHRSkwMZNSYZ9LI6z2EFJdtwFurA`, + expectedError: "could not get token claims: square/go-jose: error in cryptographic primitive", + }, + { + name: "errors when expected claims errors", + token: `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.XbPfbIHMI6arZ3Y922BhjWgQzWXcXNrz0ogtVhfEd2o`, + expectedClaims: jwt.Expected{Subject: "wrong subject"}, + expectedError: "expected claims not validated: square/go-jose/jwt: validation failed, invalid subject claim (sub)", + }, + { + name: "errors when custom claims errors", + token: `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.XbPfbIHMI6arZ3Y922BhjWgQzWXcXNrz0ogtVhfEd2o`, + customClaims: &testingCustomClaims{ReturnError: errors.New("custom claims error message")}, + expectedError: "custom claims not validated: custom claims error message", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + var customClaimsFunc func() CustomClaims + if testCase.customClaims != nil { + customClaimsFunc = func() CustomClaims { return testCase.customClaims } + } + + v, _ := New(func(ctx context.Context) (interface{}, error) { return []byte("secret"), testCase.keyFuncReturnError }, + testCase.signatureAlgorithm, + WithExpectedClaims(func() jwt.Expected { return testCase.expectedClaims }), + WithCustomClaims(customClaimsFunc), + ) + actualContext, err := v.ValidateToken(context.Background(), testCase.token) + if !equalErrors(err, testCase.expectedError) { + t.Fatalf("wanted err:\n%s\ngot:\n%+v\n", testCase.expectedError, err) + } + + if (testCase.expectedContext == nil && actualContext != nil) || (testCase.expectedContext != nil && actualContext == nil) { + t.Fatalf("wanted user context:\n%+v\ngot:\n%+v\n", testCase.expectedContext, actualContext) + } else if testCase.expectedContext != nil { + if diff := cmp.Diff(testCase.expectedContext, actualContext.(*UserContext)); diff != "" { + t.Errorf("user context mismatch (-want +got):\n%s", diff) + } + } + }) + } +} + +func Test_New(t *testing.T) { + t.Run("happy path", func(t *testing.T) { + keyFunc := func(ctx context.Context) (interface{}, error) { return nil, nil } + customClaims := func() CustomClaims { return nil } + + v, err := New(keyFunc, jose.HS256, WithCustomClaims(customClaims)) + + if !equalErrors(err, "") { + t.Fatalf("wanted err:\n%s\ngot:\n%+v\n", "", err) + } + + if v.allowedClockSkew != 0 { + t.Logf("expected allowedClockSkew to be 0 but it was %d", v.allowedClockSkew) + t.Fail() + } + + if v.keyFunc == nil { + t.Log("keyFunc was nil when it should not have been") + t.Fail() + } + + if v.signatureAlgorithm != jose.HS256 { + t.Logf("signatureAlgorithm was %q when it should have been %q", v.signatureAlgorithm, jose.HS256) + t.Fail() + } + + if v.customClaims == nil { + t.Log("customClaims was nil when it should not have been") + t.Fail() + } + }) + + t.Run("error on no keyFunc", func(t *testing.T) { + _, err := New(nil, jose.HS256) + + expectedErr := "keyFunc is required but was nil" + if !equalErrors(err, expectedErr) { + t.Fatalf("wanted err:\n%s\ngot:\n%+v\n", expectedErr, err) + } + }) +} diff --git a/validate/josev2/jwks_provider.go b/validate/josev2/jwks_provider.go new file mode 100644 index 00000000..22e10e6a --- /dev/null +++ b/validate/josev2/jwks_provider.go @@ -0,0 +1,123 @@ +package josev2 + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "sync" + "time" + + "gopkg.in/square/go-jose.v2" + + "github.com/auth0/go-jwt-middleware/internal/oidc" +) + +// JWKSProvider handles getting JWKS from the specified IssuerURL and exposes +// KeyFunc which adheres to the keyFunc signature that the Validator requires. +// Most likely you will want to use the CachingJWKSProvider as it handles +// getting and caching JWKS which can help reduce request time and potential +// rate limiting from your provider. +type JWKSProvider struct { + IssuerURL url.URL +} + +// NewJWKSProvider builds and returns a new *JWKSProvider. +func NewJWKSProvider(issuerURL url.URL) *JWKSProvider { + return &JWKSProvider{IssuerURL: issuerURL} +} + +// KeyFunc adheres to the keyFunc signature that the Validator requires. +// While it returns an interface to adhere to keyFunc, as long as the +// error is nil the type will be *jose.JSONWebKeySet. +func (p *JWKSProvider) KeyFunc(ctx context.Context) (interface{}, error) { + wkEndpoints, err := oidc.GetWellKnownEndpointsFromIssuerURL(ctx, p.IssuerURL) + if err != nil { + return nil, err + } + + jwksURI, err := url.Parse(wkEndpoints.JWKSURI) + if err != nil { + return nil, fmt.Errorf("could not parse JWKS URI from well known endpoints: %w", err) + } + + request, err := http.NewRequest(http.MethodGet, jwksURI.String(), nil) + if err != nil { + return nil, fmt.Errorf("could not build request to get JWKS: %w", err) + } + request = request.WithContext(ctx) + + response, err := http.DefaultClient.Do(request) + if err != nil { + return nil, err + } + defer response.Body.Close() + + var jwks jose.JSONWebKeySet + if err := json.NewDecoder(response.Body).Decode(&jwks); err != nil { + return nil, fmt.Errorf("could not decode jwks: %w", err) + } + + return &jwks, nil +} + +// CachingJWKSProvider handles getting JWKS from the specified IssuerURL +// and caching them for CacheTTL time. It exposes KeyFunc which adheres +// to the keyFunc signature that the Validator requires. +type CachingJWKSProvider struct { + IssuerURL url.URL + CacheTTL time.Duration + mu sync.Mutex + cache map[string]cachedJWKS +} + +type cachedJWKS struct { + jwks *jose.JSONWebKeySet + expiresAt time.Time +} + +// NewCachingJWKSProvider builds and returns a new CachingJWKSProvider. +// If cacheTTL is zero then a default value of 1 minute will be used. +func NewCachingJWKSProvider(issuerURL url.URL, cacheTTL time.Duration) *CachingJWKSProvider { + if cacheTTL == 0 { + cacheTTL = 1 * time.Minute + } + + return &CachingJWKSProvider{ + IssuerURL: issuerURL, + CacheTTL: cacheTTL, + cache: map[string]cachedJWKS{}, + } +} + +// KeyFunc adheres to the keyFunc signature that the Validator requires. +// While it returns an interface to adhere to keyFunc, as long as the +// error is nil the type will be *jose.JSONWebKeySet. +func (c *CachingJWKSProvider) KeyFunc(ctx context.Context) (interface{}, error) { + issuer := c.IssuerURL.Hostname() + + c.mu.Lock() + defer func() { + c.mu.Unlock() + }() + + if cached, ok := c.cache[issuer]; ok { + if !time.Now().After(cached.expiresAt) { + return cached.jwks, nil + } + } + + provider := JWKSProvider{IssuerURL: c.IssuerURL} + jwks, err := provider.KeyFunc(ctx) + if err != nil { + return nil, err + } + + c.cache[issuer] = cachedJWKS{ + jwks: jwks.(*jose.JSONWebKeySet), + expiresAt: time.Now().Add(c.CacheTTL), + } + + return jwks, nil +} diff --git a/validate/josev2/jwks_provider_test.go b/validate/josev2/jwks_provider_test.go new file mode 100644 index 00000000..5ce89514 --- /dev/null +++ b/validate/josev2/jwks_provider_test.go @@ -0,0 +1,285 @@ +package josev2 + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/json" + "math/big" + "net/http" + "net/http/httptest" + "net/url" + "sync" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "gopkg.in/square/go-jose.v2" + + "github.com/auth0/go-jwt-middleware/internal/oidc" +) + +func Test_JWKSProvider(t *testing.T) { + var ( + p CachingJWKSProvider + server *httptest.Server + responseBytes []byte + responseStatusCode, reqCount int + serverURL *url.URL + ) + + tests := []struct { + name string + main func(t *testing.T) + }{ + { + name: "calls out to well known endpoint", + main: func(t *testing.T) { + _, jwks := genValidRSAKeyAndJWKS(t) + var err error + responseBytes, err = json.Marshal(jwks) + if !equalErrors(err, "") { + t.Fatalf("did not want an error, but got %s", err) + } + + _, err = p.KeyFunc(context.TODO()) + if !equalErrors(err, "") { + t.Fatalf("did not want an error, but got %s", err) + } + }, + }, + { + name: "errors if it can't decode the jwks", + main: func(t *testing.T) { + responseBytes = []byte("<>") + _, err := p.KeyFunc(context.TODO()) + + wantErr := "could not decode jwks: invalid character '<' looking for beginning of value" + if !equalErrors(err, wantErr) { + t.Fatalf("wanted err:\n%s\ngot:\n%+v\n", wantErr, err) + } + }, + }, + { + name: "passes back the valid jwks", + main: func(t *testing.T) { + _, jwks := genValidRSAKeyAndJWKS(t) + var err error + responseBytes, err = json.Marshal(jwks) + if !equalErrors(err, "") { + t.Fatalf("did not want an error, but got %s", err) + } + + p.CacheTTL = time.Minute * 5 + actualJWKS, err := p.KeyFunc(context.TODO()) + if !equalErrors(err, "") { + t.Fatalf("did not want an error, but got %s", err) + } + + if want, got := &jwks, actualJWKS; !cmp.Equal(want, got) { + t.Fatalf("jwks did not match: %s", cmp.Diff(want, got)) + } + + if want, got := &jwks, p.cache[serverURL.Hostname()].jwks; !cmp.Equal(want, got) { + t.Fatalf("cached jwks did not match: %s", cmp.Diff(want, got)) + } + + expiresAt := p.cache[serverURL.Hostname()].expiresAt + if !time.Now().Before(expiresAt) { + t.Fatalf("wanted cache item expiration to be in the future but it was not: %s", expiresAt) + } + }, + }, + { + name: "returns the cached jwks when they are not expired", + main: func(t *testing.T) { + _, expectedCachedJWKS := genValidRSAKeyAndJWKS(t) + p.cache[serverURL.Hostname()] = cachedJWKS{ + jwks: &expectedCachedJWKS, + expiresAt: time.Now().Add(1 * time.Minute), + } + + actualJWKS, err := p.KeyFunc(context.TODO()) + if !equalErrors(err, "") { + t.Fatalf("did not want an error, but got %s", err) + } + + if want, got := &expectedCachedJWKS, actualJWKS; !cmp.Equal(want, got) { + t.Fatalf("cached jwks did not match: %s", cmp.Diff(want, got)) + } + + if reqCount > 0 { + t.Fatalf("did not want any requests since we should have read from the cache, but we got %d requests", reqCount) + } + }, + }, + { + name: "re-caches the jwks if they have expired", + main: func(t *testing.T) { + _, expiredCachedJWKS := genValidRSAKeyAndJWKS(t) + expiresAt := time.Now().Add(-10 * time.Minute) + p.cache[server.URL] = cachedJWKS{ + jwks: &expiredCachedJWKS, + expiresAt: expiresAt, + } + _, jwks := genValidRSAKeyAndJWKS(t) + var err error + responseBytes, err = json.Marshal(jwks) + if !equalErrors(err, "") { + t.Fatalf("did not want an error, but got %s", err) + } + + p.CacheTTL = time.Minute * 5 + actualJWKS, err := p.KeyFunc(context.TODO()) + if !equalErrors(err, "") { + t.Fatalf("did not want an error, but got %s", err) + } + + if want, got := &jwks, actualJWKS; !cmp.Equal(want, got) { + t.Fatalf("jwks did not match: %s", cmp.Diff(want, got)) + } + + if want, got := &jwks, p.cache[serverURL.Hostname()].jwks; !cmp.Equal(want, got) { + t.Fatalf("cached jwks did not match: %s", cmp.Diff(want, got)) + } + + cacheExpiresAt := p.cache[serverURL.Hostname()].expiresAt + if !time.Now().Before(cacheExpiresAt) { + t.Fatalf("wanted cache item expiration to be in the future but it was not: %s", cacheExpiresAt) + } + }, + }, + { + name: "only calls the API once when multiple requests come in", + main: func(t *testing.T) { + _, jwks := genValidRSAKeyAndJWKS(t) + var err error + responseBytes, err = json.Marshal(jwks) + if !equalErrors(err, "") { + t.Fatalf("did not want an error, but got %s", err) + } + + p.CacheTTL = time.Minute * 5 + + wg := sync.WaitGroup{} + for i := 0; i < 50; i++ { + wg.Add(1) + go func(t *testing.T) { + actualJWKS, err := p.KeyFunc(context.TODO()) + if !equalErrors(err, "") { + t.Errorf("did not want an error, but got %s", err) + } + + if want, got := &jwks, actualJWKS; !cmp.Equal(want, got) { + t.Errorf("jwks did not match: %s", cmp.Diff(want, got)) + } + + wg.Done() + }(t) + } + wg.Wait() + + actualJWKS, err := p.KeyFunc(context.TODO()) + if !equalErrors(err, "") { + t.Fatalf("did not want an error, but got %s", err) + } + + if want, got := &jwks, actualJWKS; !cmp.Equal(want, got) { + t.Fatalf("jwks did not match: %s", cmp.Diff(want, got)) + } + + if reqCount != 2 { + t.Fatalf("only wanted 2 requests (well known and jwks) , but we got %d requests", reqCount) + } + + if want, got := &jwks, p.cache[serverURL.Hostname()].jwks; !cmp.Equal(want, got) { + t.Fatalf("cached jwks did not match: %s", cmp.Diff(want, got)) + } + + cacheExpiresAt := p.cache[serverURL.Hostname()].expiresAt + if !time.Now().Before(cacheExpiresAt) { + t.Fatalf("wanted cache item expiration to be in the future but it was not: %s", cacheExpiresAt) + } + }, + }, + } + + for _, test := range tests { + var reqCallMutex sync.Mutex + + reqCount = 0 + responseBytes = []byte(`{"kid":""}`) + responseStatusCode = http.StatusOK + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // handle mutex things + reqCallMutex.Lock() + defer reqCallMutex.Unlock() + reqCount++ + w.WriteHeader(responseStatusCode) + + switch r.URL.String() { + case "/.well-known/openid-configuration": + wk := oidc.WellKnownEndpoints{JWKSURI: server.URL + "/url_for_jwks"} + err := json.NewEncoder(w).Encode(wk) + if !equalErrors(err, "") { + t.Fatalf("did not want an error, but got %s", err) + } + case "/url_for_jwks": + _, err := w.Write(responseBytes) + if !equalErrors(err, "") { + t.Fatalf("did not want an error, but got %s", err) + } + default: + t.Fatalf("do not know how to handle url %s", r.URL.String()) + } + })) + defer server.Close() + serverURL = mustParseURL(server.URL) + + p = CachingJWKSProvider{ + IssuerURL: *serverURL, + CacheTTL: 0, + cache: map[string]cachedJWKS{}, + } + + t.Run(test.name, test.main) + } +} + +func mustParseURL(toParse string) *url.URL { + parsed, err := url.Parse(toParse) + if err != nil { + panic(err) + } + + return parsed +} + +func genValidRSAKeyAndJWKS(t *testing.T) (*rsa.PrivateKey, jose.JSONWebKeySet) { + ca := &x509.Certificate{ + SerialNumber: big.NewInt(1653), + } + priv, _ := rsa.GenerateKey(rand.Reader, 2048) + rawCert, err := x509.CreateCertificate(rand.Reader, ca, ca, &priv.PublicKey, priv) + if !equalErrors(err, "") { + t.Fatalf("did not want an error, but got %s", err) + } + + jwks := jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + { + Key: priv, + KeyID: "kid", + Certificates: []*x509.Certificate{ + { + Raw: rawCert, + }, + }, + CertificateThumbprintSHA1: []uint8{}, + CertificateThumbprintSHA256: []uint8{}, + }, + }, + } + return priv, jwks +} diff --git a/validate/jwt-go/examples/main.go b/validate/jwt-go/examples/main.go new file mode 100644 index 00000000..b0fb3a70 --- /dev/null +++ b/validate/jwt-go/examples/main.go @@ -0,0 +1,95 @@ +package main + +import ( + "context" + "encoding/json" + "log" + "net/http" + + "github.com/golang-jwt/jwt/v4" + "github.com/pkg/errors" + + "github.com/auth0/go-jwt-middleware" + "github.com/auth0/go-jwt-middleware/validate/jwt-go" +) + +var handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + claims := r.Context().Value(jwtmiddleware.ContextKey{}).(*CustomClaimsExample) + + payload, err := json.Marshal(claims) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + w.Write(payload) +}) + +// CustomClaimsExample contains custom data we want from the token. +type CustomClaimsExample struct { + Username string `json:"username"` + ShouldReject bool `json:"shouldReject,omitempty"` + jwt.RegisteredClaims +} + +// Validate does nothing for this example, however we can +// validate in here any expectations we have on our claims. +func (c *CustomClaimsExample) Validate(ctx context.Context) error { + if c.ShouldReject { + return errors.New("should reject was set to true") + } + return nil +} + +func main() { + keyFunc := func(t *jwt.Token) (interface{}, error) { + // Our token must be signed using this data. + return []byte("secret"), nil + } + + customClaims := func() jwtgo.CustomClaims { + // We want this struct to be filled in with + // our custom claims from the token. + return &CustomClaimsExample{} + } + + // Set up the jwt-go validator. + validator, err := jwtgo.New( + keyFunc, + "HS256", + jwtgo.WithCustomClaims(customClaims), + ) + if err != nil { + log.Fatalf("failed to set up the jwt-go validator: %v", err) + } + + // Set up the middleware. + middleware := jwtmiddleware.New(validator.ValidateToken) + + http.ListenAndServe("0.0.0.0:3000", middleware.CheckJWT(handler)) + // Try it out with: + // + // eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJqd3Rnby1leGFtcGxlIiwic3ViIjoiMTIzNDU2Nzg5MCIsImlhdCI6MTUxNjIzOTAyMiwidXNlcm5hbWUiOiJ1c2VyMTIzIn0.ha_JgA29vSAb3HboPRXEi9Dm5zy7ARzd4P8AFoYP9t0 + // + // which is signed with 'secret' and has the data: + // { + // "iss": "jwtgo-example", + // "sub": "1234567890", + // "iat": 1516239022, + // "username": "user123" + // } + // + // You can also try out the custom validation with: + // + // eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJqd3Rnby1leGFtcGxlIiwic3ViIjoiMTIzNDU2Nzg5MCIsImlhdCI6MTUxNjIzOTAyMiwidXNlcm5hbWUiOiJ1c2VyMTIzIiwic2hvdWxkUmVqZWN0Ijp0cnVlfQ.awZ0DFpJ-hH5xn-q-sZHJWj7oTAOkPULwgFO4O6D67o + // + // which is signed with 'secret' and has the data: + // { + // "iss": "jwtgo-example", + // "sub": "1234567890", + // "iat": 1516239022, + // "username": "user123", + // "shouldReject": true + // } +} diff --git a/validate/jwt-go/jwtgo.go b/validate/jwt-go/jwtgo.go new file mode 100644 index 00000000..7105000c --- /dev/null +++ b/validate/jwt-go/jwtgo.go @@ -0,0 +1,86 @@ +package jwtgo + +import ( + "context" + "fmt" + + "github.com/golang-jwt/jwt/v4" + "github.com/pkg/errors" +) + +// Validator to use with the jwt-go package. +type Validator struct { + keyFunc func(*jwt.Token) (interface{}, error) // Required. + signatureAlgorithm string // Required. + customClaims func() CustomClaims // Optional. +} + +// Option is how options for the Validator are set up. +type Option func(*Validator) + +// CustomClaims defines any custom data / claims wanted. +// The Validator will call the Validate function which +// is where custom validation logic can be defined. +type CustomClaims interface { + jwt.Claims + Validate(context.Context) error +} + +// WithCustomClaims sets up a function that returns the object +// CustomClaims that will be unmarshalled into and on which +// Validate is called on for custom validation. If this option +// is not used the Validator will do nothing for custom claims. +func WithCustomClaims(f func() CustomClaims) Option { + return func(v *Validator) { + v.customClaims = f + } +} + +// New sets up a new Validator with the required keyFunc +// and signatureAlgorithm as well as custom options. +func New( + keyFunc jwt.Keyfunc, + signatureAlgorithm string, + opts ...Option, +) (*Validator, error) { + if keyFunc == nil { + return nil, errors.New("keyFunc is required but was nil") + } + + v := &Validator{ + keyFunc: keyFunc, + signatureAlgorithm: signatureAlgorithm, + customClaims: nil, + } + + for _, opt := range opts { + opt(v) + } + + return v, nil +} + +// ValidateToken validates the passed in JWT using the jwt-go package. +func (v *Validator) ValidateToken(ctx context.Context, token string) (interface{}, error) { + var claims jwt.Claims = &jwt.RegisteredClaims{} + if v.customClaims != nil { + claims = v.customClaims() + } + + parser := &jwt.Parser{} + if v.signatureAlgorithm != "" { + parser.ValidMethods = []string{v.signatureAlgorithm} + } + + if _, err := parser.ParseWithClaims(token, claims, v.keyFunc); err != nil { + return nil, fmt.Errorf("could not parse the token: %w", err) + } + + if customClaims, ok := claims.(CustomClaims); ok { + if err := customClaims.Validate(ctx); err != nil { + return nil, fmt.Errorf("custom claims not validated: %w", err) + } + } + + return claims, nil +} diff --git a/validate/jwt-go/jwtgo_test.go b/validate/jwt-go/jwtgo_test.go new file mode 100644 index 00000000..9abdf3de --- /dev/null +++ b/validate/jwt-go/jwtgo_test.go @@ -0,0 +1,140 @@ +package jwtgo + +import ( + "context" + "testing" + + "github.com/golang-jwt/jwt/v4" + "github.com/google/go-cmp/cmp" + "github.com/pkg/errors" +) + +type testingCustomClaims struct { + Foo string `json:"foo"` + ReturnError error + jwt.RegisteredClaims +} + +func (tcc *testingCustomClaims) Validate(ctx context.Context) error { + return tcc.ReturnError +} + +func equalErrors(actual error, expected string) bool { + if actual == nil { + return expected == "" + } + return actual.Error() == expected +} + +func Test_Validate(t *testing.T) { + testCases := []struct { + name string + signatureAlgorithm string + token string + keyFuncReturnError error + customClaims CustomClaims + expectedError string + expectedContext jwt.Claims + }{ + { + name: "happy path", + token: `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.Rq8IxqeX7eA6GgYxlcHdPFVRNFFZc5rEI3MQTZZbK3I`, + expectedContext: &jwt.RegisteredClaims{Subject: "1234567890"}, + }, + { + // we want to test that when it expects RSA but we send + // HMAC encrypted with the server public key it will + // error + name: "errors on wrong algorithm", + signatureAlgorithm: "PS256", + token: `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.XbPfbIHMI6arZ3Y922BhjWgQzWXcXNrz0ogtVhfEd2o`, + expectedError: "could not parse the token: signing method HS256 is invalid", + }, + { + name: "errors on wrong token format errors", + expectedError: "could not parse the token: token contains an invalid number of segments", + }, + { + name: "errors when the key func errors", + token: `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.XbPfbIHMI6arZ3Y922BhjWgQzWXcXNrz0ogtVhfEd2o`, + keyFuncReturnError: errors.New("key func error message"), + expectedError: "could not parse the token: key func error message", + }, + { + name: "errors when signature is invalid", + token: `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.hDyICUnkCrwFJnkJHRSkwMZNSYZ9LI6z2EFJdtwFurA`, + expectedError: "could not parse the token: signature is invalid", + }, + { + name: "errors when custom claims errors", + token: `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiZm9vIjoiYmFyIiwiaWF0IjoxNTE2MjM5MDIyfQ.DFTWyYib4-xFdMaEZFAYx5AKMPNS7Hhl4kcyjQVinYc`, + customClaims: &testingCustomClaims{ReturnError: errors.New("custom claims error message")}, + expectedError: "custom claims not validated: custom claims error message", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + var customClaimsFunc func() CustomClaims + if testCase.customClaims != nil { + customClaimsFunc = func() CustomClaims { return testCase.customClaims } + } + + v, _ := New(func(token *jwt.Token) (interface{}, error) { + return []byte("secret"), testCase.keyFuncReturnError + }, + testCase.signatureAlgorithm, + WithCustomClaims(customClaimsFunc), + ) + actualContext, err := v.ValidateToken(context.Background(), testCase.token) + if !equalErrors(err, testCase.expectedError) { + t.Fatalf("wanted err:\n%s\ngot:\n%+v\n", testCase.expectedError, err) + } + + if (testCase.expectedContext == nil && actualContext != nil) || (testCase.expectedContext != nil && actualContext == nil) { + t.Fatalf("wanted user context:\n%+v\ngot:\n%+v\n", testCase.expectedContext, actualContext) + } else if testCase.expectedContext != nil { + if diff := cmp.Diff(testCase.expectedContext, actualContext.(jwt.Claims)); diff != "" { + t.Errorf("user context mismatch (-want +got):\n%s", diff) + } + } + }) + } +} + +func Test_New(t *testing.T) { + t.Run("happy path", func(t *testing.T) { + keyFunc := func(t *jwt.Token) (interface{}, error) { return nil, nil } + customClaims := func() CustomClaims { return nil } + + v, err := New(keyFunc, "HS256", WithCustomClaims(customClaims)) + + if !equalErrors(err, "") { + t.Fatalf("wanted err:\n%s\ngot:\n%+v\n", "", err) + } + + if v.keyFunc == nil { + t.Log("keyFunc was nil when it should not have been") + t.Fail() + } + + if v.signatureAlgorithm != "HS256" { + t.Logf("signatureAlgorithm was %q when it should have been %q", v.signatureAlgorithm, "HS256") + t.Fail() + } + + if v.customClaims == nil { + t.Log("customClaims was nil when it should not have been") + t.Fail() + } + }) + + t.Run("error on no keyFunc", func(t *testing.T) { + _, err := New(nil, "HS256") + + expectedErr := "keyFunc is required but was nil" + if !equalErrors(err, expectedErr) { + t.Fatalf("wanted err:\n%s\ngot:\n%+v\n", expectedErr, err) + } + }) +}