From d9d9e34d23fb42b42b4dd66f64bb2c487e49355c Mon Sep 17 00:00:00 2001 From: Tatsuro Alpert Date: Tue, 1 May 2018 16:30:00 -0400 Subject: [PATCH 1/3] basic auth middleware --- middleware_auth.go | 103 +++++++++++++++++++++++++++ middleware_auth_test.go | 151 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 254 insertions(+) create mode 100644 middleware_auth.go create mode 100644 middleware_auth_test.go diff --git a/middleware_auth.go b/middleware_auth.go new file mode 100644 index 0000000..78ea9fd --- /dev/null +++ b/middleware_auth.go @@ -0,0 +1,103 @@ +package rye + +import ( + "context" + "encoding/base64" + "errors" + "net/http" + "strings" +) + +/* +NewMiddlewareAuth creates a new middleware to extract the Authorization header +from a request and validate it. It accepts a func of type AuthFunc which is +used to do the credential validation. +An AuthFunc for Basic auth is provided here. + +Example usage: + + routes.Handle("/some/route", myMWHandler.Handle( + []rye.Handler{ + rye.NewMiddlewareAuth(rye.NewBasicAuthFunc(map[string]string{ + "user1": "my_password", + })), + yourHandler, + })).Methods("POST") +*/ + +type AuthFunc func(context.Context, string) *Response + +func NewMiddlewareAuth(authFunc AuthFunc) func(rw http.ResponseWriter, req *http.Request) *Response { + return func(rw http.ResponseWriter, r *http.Request) *Response { + auth := r.Header.Get("Authorization") + if auth == "" { + return &Response{ + Err: errors.New("unauthorized: no authentication provided"), + StatusCode: http.StatusUnauthorized, + } + } + + return authFunc(r.Context(), auth) + } +} + +/*********** + Basic Auth +***********/ + +func NewBasicAuthFunc(userPass map[string]string) AuthFunc { + return basicAuth(userPass).authenticate +} + +type basicAuth map[string]string + +const AUTH_USERNAME_KEY = "request-username" + +// basicAuth.authenticate meets the AuthFunc type +func (b basicAuth) authenticate(ctx context.Context, auth string) *Response { + errResp := &Response{ + Err: errors.New("unauthorized: invalid authentication provided"), + StatusCode: http.StatusUnauthorized, + } + + // parse the Authorization header + u, p, ok := parseBasicAuth(auth) + if !ok { + return errResp + } + + // get the password + pass, ok := b[u] + if !ok { + return errResp + } + + // compare the password + if pass != p { + return errResp + } + + // add username to the context + return &Response{ + Context: context.WithValue(ctx, AUTH_USERNAME_KEY, u), + } +} + +// parseBasicAuth parses an HTTP Basic Authentication string. +// taken from net/http/request.go +func parseBasicAuth(auth string) (username, password string, ok bool) { + const prefix = "Basic " + if !strings.HasPrefix(auth, prefix) { + return + } + c, err := base64.StdEncoding.DecodeString(auth[len(prefix):]) + if err != nil { + return + } + cs := string(c) + s := strings.IndexByte(cs, ':') + if s < 0 { + return + } + return cs[:s], cs[s+1:], true +} diff --git a/middleware_auth_test.go b/middleware_auth_test.go new file mode 100644 index 0000000..2fbc320 --- /dev/null +++ b/middleware_auth_test.go @@ -0,0 +1,151 @@ +package rye + +import ( + "net/http" + "net/http/httptest" + + "context" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +const AUTH_HEADER_NAME = "Authorization" + +var _ = Describe("Auth Middleware", func() { + var ( + request *http.Request + response *httptest.ResponseRecorder + + testHandler func(http.ResponseWriter, *http.Request) *Response + ) + + BeforeEach(func() { + response = httptest.NewRecorder() + }) + + Context("auth", func() { + var ( + fakeAuth *recorder + ) + + BeforeEach(func() { + fakeAuth = &recorder{} + + testHandler = NewMiddlewareAuth(fakeAuth.authFunc) + request = &http.Request{ + Header: map[string][]string{}, + } + }) + + It("passes the header to the auth func", func() { + testAuth := "foobar" + request.Header.Add(AUTH_HEADER_NAME, testAuth) + resp := testHandler(response, request) + + Expect(resp).To(BeNil()) + Expect(fakeAuth.header).To(Equal(testAuth)) + }) + + Context("when no header is found", func() { + It("errors", func() { + resp := testHandler(response, request) + + Expect(resp).ToNot(BeNil()) + Expect(resp.Err).ToNot(BeNil()) + Expect(resp.Err.Error()).To(ContainSubstring("no authentication")) + }) + }) + }) + + Context("Basic Auth", func() { + var ( + username = "user1" + pass = "mypass" + ) + + BeforeEach(func() { + testHandler = NewMiddlewareAuth(NewBasicAuthFunc(map[string]string{ + username: pass, + })) + + request = &http.Request{ + Header: map[string][]string{}, + } + }) + + It("validates the password", func() { + request.SetBasicAuth(username, pass) + resp := testHandler(response, request) + + Expect(resp.Err).To(BeNil()) + }) + + It("adds the username to context", func() { + request.SetBasicAuth(username, pass) + resp := testHandler(response, request) + + Expect(resp.Err).To(BeNil()) + + ctxUname := resp.Context.Value(AUTH_USERNAME_KEY) + uname, ok := ctxUname.(string) + Expect(ok).To(BeTrue()) + Expect(uname).To(Equal(username)) + }) + + It("preserves the request context", func() { + + }) + + It("errors if username unknown", func() { + request.SetBasicAuth("noname", pass) + resp := testHandler(response, request) + + Expect(resp.Err).ToNot(BeNil()) + Expect(resp.Err.Error()).To(ContainSubstring("invalid auth")) + }) + + It("errors if password wrong", func() { + request.SetBasicAuth(username, "wrong") + resp := testHandler(response, request) + + Expect(resp.Err).ToNot(BeNil()) + Expect(resp.Err.Error()).To(ContainSubstring("invalid auth")) + }) + + Context("parseBasicAuth", func() { + It("errors if header not basic", func() { + request.Header.Add(AUTH_HEADER_NAME, "wrong") + resp := testHandler(response, request) + + Expect(resp.Err).ToNot(BeNil()) + Expect(resp.Err.Error()).To(ContainSubstring("invalid auth")) + }) + + It("errors if header not base64", func() { + request.Header.Add(AUTH_HEADER_NAME, "Basic ------") + resp := testHandler(response, request) + + Expect(resp.Err).ToNot(BeNil()) + Expect(resp.Err.Error()).To(ContainSubstring("invalid auth")) + }) + + It("errors if header wrong format", func() { + request.Header.Add(AUTH_HEADER_NAME, "Basic YXNkZgo=") // asdf no `:` + resp := testHandler(response, request) + + Expect(resp.Err).ToNot(BeNil()) + Expect(resp.Err.Error()).To(ContainSubstring("invalid auth")) + }) + }) + }) +}) + +type recorder struct { + header string +} + +func (r *recorder) authFunc(ctx context.Context, s string) *Response { + r.header = s + return nil +} From 52e86fb74bfa95e84557f06f56d37aafe88cfb03 Mon Sep 17 00:00:00 2001 From: Tatsuro Alpert Date: Wed, 2 May 2018 17:36:23 -0400 Subject: [PATCH 2/3] new version of jwt auth middleware --- middleware_auth.go | 56 +++++++++++++++++++++++++++++++++++++++--- middleware_jwt.go | 52 ++++++--------------------------------- middleware_jwt_test.go | 19 +++++++++++++- 3 files changed, 78 insertions(+), 49 deletions(-) diff --git a/middleware_auth.go b/middleware_auth.go index 78ea9fd..c954fe7 100644 --- a/middleware_auth.go +++ b/middleware_auth.go @@ -4,15 +4,18 @@ import ( "context" "encoding/base64" "errors" + "fmt" "net/http" "strings" + + jwt "github.com/dgrijalva/jwt-go" ) /* NewMiddlewareAuth creates a new middleware to extract the Authorization header from a request and validate it. It accepts a func of type AuthFunc which is used to do the credential validation. -An AuthFunc for Basic auth is provided here. +An AuthFuncs for Basic auth and JWT are provided here. Example usage: @@ -83,14 +86,15 @@ func (b basicAuth) authenticate(ctx context.Context, auth string) *Response { } } +const basicPrefix = "Basic " + // parseBasicAuth parses an HTTP Basic Authentication string. // taken from net/http/request.go func parseBasicAuth(auth string) (username, password string, ok bool) { - const prefix = "Basic " - if !strings.HasPrefix(auth, prefix) { + if !strings.HasPrefix(auth, basicPrefix) { return } - c, err := base64.StdEncoding.DecodeString(auth[len(prefix):]) + c, err := base64.StdEncoding.DecodeString(auth[len(basicPrefix):]) if err != nil { return } @@ -101,3 +105,47 @@ func parseBasicAuth(auth string) (username, password string, ok bool) { } return cs[:s], cs[s+1:], true } + +/**** + JWT +****/ + +type jwtAuth struct { + secret string +} + +func NewJWTAuthFunc(secret string) AuthFunc { + j := &jwtAuth{secret: secret} + return j.authenticate +} + +const bearerPrefix = "Bearer " + +func (j *jwtAuth) authenticate(ctx context.Context, auth string) *Response { + // Remove 'Bearer' prefix + if !strings.HasPrefix(auth, bearerPrefix) && !strings.HasPrefix(auth, strings.ToLower(bearerPrefix)) { + return &Response{ + Err: errors.New("unauthorized: invalid authentication provided"), + StatusCode: http.StatusUnauthorized, + } + } + + token := auth[len(bearerPrefix):] + + _, err := jwt.Parse(token, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("Unexpected signing method") + } + return []byte(j.secret), nil + }) + if err != nil { + return &Response{ + Err: err, + StatusCode: http.StatusUnauthorized, + } + } + + return &Response{ + Context: context.WithValue(ctx, CONTEXT_JWT, token), + } +} diff --git a/middleware_jwt.go b/middleware_jwt.go index 7487c3e..6fae309 100644 --- a/middleware_jwt.go +++ b/middleware_jwt.go @@ -1,13 +1,6 @@ package rye -import ( - "context" - "fmt" - "net/http" - "regexp" - - "github.com/dgrijalva/jwt-go" -) +import "net/http" const ( CONTEXT_JWT = "rye-middlewarejwt-jwt" @@ -19,6 +12,12 @@ type jwtVerify struct { } /* +This middleware is deprecated. Use NewMiddlewareAuth with NewJWTAuthFunc instead. + +This remains here as a shim for backwards compatibility. + +--------------------------------------------------------------------------- + This middleware provides JWT verification functionality You can use this middleware by specifying `rye.NewMiddlewareJWT(shared_secret)` @@ -53,40 +52,5 @@ Access to that is simple (using the CONTEXT_JWT constant as a key) */ func NewMiddlewareJWT(secret string) func(rw http.ResponseWriter, req *http.Request) *Response { - j := &jwtVerify{secret: secret} - return j.handle -} - -func (j *jwtVerify) handle(rw http.ResponseWriter, req *http.Request) *Response { - - tokenHeader := req.Header.Get("Authorization") - - if tokenHeader == "" { - return &Response{ - Err: fmt.Errorf("JWT token must be passed with Authorization header"), - StatusCode: 400, - } - } - - // Remove 'Bearer' prefix - p, _ := regexp.Compile(`(?i)bearer\s+`) - j.token = p.ReplaceAllString(tokenHeader, "") - - _, err := jwt.Parse(j.token, func(token *jwt.Token) (interface{}, error) { - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, fmt.Errorf("Unexpected signing method") - } - return []byte(j.secret), nil - }) - - if err != nil { - return &Response{ - Err: err, - StatusCode: 401, - } - } - - ctx := context.WithValue(req.Context(), CONTEXT_JWT, j.token) - - return &Response{Context: ctx} + return NewMiddlewareAuth(NewJWTAuthFunc(secret)) } diff --git a/middleware_jwt_test.go b/middleware_jwt_test.go index b0bc0f7..b6d112a 100644 --- a/middleware_jwt_test.go +++ b/middleware_jwt_test.go @@ -35,13 +35,21 @@ var _ = Describe("JWT Middleware", func() { Expect(resp.Context).ToNot(BeNil()) Expect(resp.Context.Value(CONTEXT_JWT)).To(Equal(hs256_jwt)) }) + + It("lower case bearer is also accepted", func() { + request.Header.Add("Authorization", fmt.Sprintf("bearer %s", hs256_jwt)) + resp := NewMiddlewareJWT(shared_secret)(response, request) + Expect(resp).ToNot(BeNil()) + Expect(resp.Context).ToNot(BeNil()) + Expect(resp.Context.Value(CONTEXT_JWT)).To(Equal(hs256_jwt)) + }) }) Context("when no token is passed", func() { It("should return an error", func() { resp := NewMiddlewareJWT(shared_secret)(response, request) Expect(resp).ToNot(BeNil()) - Expect(resp.Error()).To(ContainSubstring("JWT token must be passed")) + Expect(resp.Error()).To(ContainSubstring("no authentication provided")) }) }) @@ -62,5 +70,14 @@ var _ = Describe("JWT Middleware", func() { Expect(resp.Error()).To(ContainSubstring("signing method")) }) }) + + Context("token with wrong header format", func() { + It("should return an error", func() { + request.Header.Add("Authorization", fmt.Sprintf("foo %s", rs256_jwt)) + resp := NewMiddlewareJWT(shared_secret)(response, request) + Expect(resp).ToNot(BeNil()) + Expect(resp.Error()).To(ContainSubstring("invalid authentication")) + }) + }) }) }) From cadc9676f0bf9185edbc2d9e8f456879dc6f0130 Mon Sep 17 00:00:00 2001 From: Tatsuro Alpert Date: Wed, 2 May 2018 17:55:18 -0400 Subject: [PATCH 3/3] docs and example --- README.md | 4 ++-- example/rye_example.go | 15 +++++++++++++-- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index b0b88d3..d8dcb00 100644 --- a/README.md +++ b/README.md @@ -226,7 +226,7 @@ the example using Gorilla: | [Access Token](middleware_accesstoken.go) | Provide Access Token validation | | [CIDR](middleware_cidr.go) | Provide request IP whitelisting | | [CORS](middleware_cors.go) | Provide CORS functionality for routes | -| [JWT](middleware_jwt.go) | Provide JWT validation | +| [Auth](middleware_auth.go) | Provide Authorization header validation (basic auth, JWT) | | [Route Logger](middleware_routelogger.go) | Provide basic logging for a specific route | | [Static File](middleware_static_file.go) | Provides serving a single file | | [Static Filesystem](middleware_static_filesystem.go) | Provides serving a single file | @@ -234,7 +234,7 @@ the example using Gorilla: ### A Note on the JWT Middleware -The [JWT Middleware](middleware_jwt.go) pushes the JWT token onto the Context for use by other middlewares in the chain. This is a convenience that allows any part of your middleware chain quick access to the JWT. Example usage might include a middleware that needs access to your user id or email address stored in the JWT. To access this `Context` variable, the code is very simple: +The [JWT Middleware](middleware_auth.go) pushes the JWT token onto the Context for use by other middlewares in the chain. This is a convenience that allows any part of your middleware chain quick access to the JWT. Example usage might include a middleware that needs access to your user id or email address stored in the JWT. To access this `Context` variable, the code is very simple: ```go func getJWTfromContext(rw http.ResponseWriter, r *http.Request) *rye.Response { // Retrieving the value is easy! diff --git a/example/rye_example.go b/example/rye_example.go index a08b8a0..2a00927 100644 --- a/example/rye_example.go +++ b/example/rye_example.go @@ -7,9 +7,9 @@ import ( "net/http" "github.com/InVisionApp/rye" - log "github.com/sirupsen/logrus" "github.com/cactus/go-statsd-client/statsd" "github.com/gorilla/mux" + log "github.com/sirupsen/logrus" ) func main() { @@ -41,11 +41,22 @@ func main() { homeHandler, })).Methods("GET", "OPTIONS") + // If you perform an `curl -i http://localhost:8181/jwt \ + // -H "Authorization: Basic dXNlcjE6cGFzczEK" + // you will see that we are allowed through to the handler, if the header is changed, you will get a 401 + routes.Handle("/basic-auth", middlewareHandler.Handle([]rye.Handler{ + rye.NewMiddlewareAuth(rye.NewBasicAuthFunc(map[string]string{ + "user1": "pass1", + "user2": "pass2", + })), + getJwtFromContextHandler, + })).Methods("GET") + // If you perform an `curl -i http://localhost:8181/jwt \ // -H "Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.TJVA95OrM7E2cBab30RMHrHDcEfxjoYZgeFONFh7HgQ" // you will see that we are allowed through to the handler, if the sample token is changed, we will get a 401 routes.Handle("/jwt", middlewareHandler.Handle([]rye.Handler{ - rye.NewMiddlewareJWT("secret"), + rye.NewMiddlewareAuth(rye.NewJWTAuthFunc("secret")), getJwtFromContextHandler, })).Methods("GET")