diff --git a/context.go b/context.go index e78a2385..88f8e221 100644 --- a/context.go +++ b/context.go @@ -76,6 +76,7 @@ type Context struct { // methodNotAllowed hint methodNotAllowed bool + methodsAllowed []methodTyp // allowed methods in case of a 405 } // Reset a routing context to its initial state. diff --git a/mux.go b/mux.go index 0d1caa6e..977aa52d 100644 --- a/mux.go +++ b/mux.go @@ -378,11 +378,11 @@ func (mx *Mux) NotFoundHandler() http.HandlerFunc { // MethodNotAllowedHandler returns the default Mux 405 responder whenever // a method cannot be resolved for a route. -func (mx *Mux) MethodNotAllowedHandler() http.HandlerFunc { +func (mx *Mux) MethodNotAllowedHandler(methodsAllowed ...methodTyp) http.HandlerFunc { if mx.methodNotAllowedHandler != nil { return mx.methodNotAllowedHandler } - return methodNotAllowedHandler + return methodNotAllowedHandler(methodsAllowed...) } // handle registers a http.Handler in the routing tree for a particular http method @@ -445,7 +445,7 @@ func (mx *Mux) routeHTTP(w http.ResponseWriter, r *http.Request) { return } if rctx.methodNotAllowed { - mx.MethodNotAllowedHandler().ServeHTTP(w, r) + mx.MethodNotAllowedHandler(rctx.methodsAllowed...).ServeHTTP(w, r) } else { mx.NotFoundHandler().ServeHTTP(w, r) } @@ -480,8 +480,14 @@ func (mx *Mux) updateRouteHandler() { } // methodNotAllowedHandler is a helper function to respond with a 405, -// method not allowed. -func methodNotAllowedHandler(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(405) - w.Write(nil) +// method not allowed. It sets the Allow header with the list of allowed +// methods for the route. +func methodNotAllowedHandler(methodsAllowed ...methodTyp) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + for _, m := range methodsAllowed { + w.Header().Add("Allow", reverseMethodMap[m]) + } + w.WriteHeader(405) + w.Write(nil) + } } diff --git a/mux_test.go b/mux_test.go index 68fc94c0..0f8f8995 100644 --- a/mux_test.go +++ b/mux_test.go @@ -392,6 +392,43 @@ func TestMuxNestedNotFound(t *testing.T) { } } +func TestMethodNotAllowed(t *testing.T) { + r := NewRouter() + + r.Get("/hi", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hi, get")) + }) + + r.Head("/hi", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hi, head")) + }) + + ts := httptest.NewServer(r) + defer ts.Close() + + t.Run("Registered Method", func(t *testing.T) { + resp, _ := testRequest(t, ts, "GET", "/hi", nil) + if resp.StatusCode != 200 { + t.Fatal(resp.Status) + } + if resp.Header.Values("Allow") != nil { + t.Fatal("allow should be empty when method is registered") + } + }) + + t.Run("Unregistered Method", func(t *testing.T) { + resp, _ := testRequest(t, ts, "POST", "/hi", nil) + if resp.StatusCode != 405 { + t.Fatal(resp.Status) + } + allowedMethods := resp.Header.Values("Allow") + if len(allowedMethods) != 2 || allowedMethods[0] != "GET" || allowedMethods[1] != "HEAD" { + t.Fatal("Allow header should contain 2 headers: GET, HEAD. Received: ", allowedMethods) + + } + }) +} + func TestMuxNestedMethodNotAllowed(t *testing.T) { r := NewRouter() r.Get("/root", func(w http.ResponseWriter, r *http.Request) { @@ -1771,6 +1808,7 @@ func BenchmarkMux(b *testing.B) { mx := NewRouter() mx.Get("/", h1) mx.Get("/hi", h2) + mx.Post("/hi-post", h2) // used to benchmark 405 responses mx.Get("/sup/{id}/and/{this}", h3) mx.Get("/sup/{id}/{bar:foo}/{this}", h3) @@ -1787,6 +1825,7 @@ func BenchmarkMux(b *testing.B) { routes := []string{ "/", "/hi", + "/hi-post", "/sup/123/and/this", "/sup/123/foo/this", "/sharing/z/aBc", // subrouter-1 diff --git a/tree.go b/tree.go index 4189b522..c7d3bc57 100644 --- a/tree.go +++ b/tree.go @@ -43,6 +43,18 @@ var methodMap = map[string]methodTyp{ http.MethodTrace: mTRACE, } +var reverseMethodMap = map[methodTyp]string{ + mCONNECT: http.MethodConnect, + mDELETE: http.MethodDelete, + mGET: http.MethodGet, + mHEAD: http.MethodHead, + mOPTIONS: http.MethodOptions, + mPATCH: http.MethodPatch, + mPOST: http.MethodPost, + mPUT: http.MethodPut, + mTRACE: http.MethodTrace, +} + // RegisterMethod adds support for custom HTTP method handlers, available // via Router#Method and Router#MethodFunc func RegisterMethod(method string) { @@ -454,6 +466,13 @@ func (n *node) findRoute(rctx *Context, method methodTyp, path string) *node { return xn } + for endpoints := range xn.endpoints { + if endpoints == mALL || endpoints == mSTUB { + continue + } + rctx.methodsAllowed = append(rctx.methodsAllowed, endpoints) + } + // flag that the routing context found a route, but not a corresponding // supported method rctx.methodNotAllowed = true @@ -493,6 +512,13 @@ func (n *node) findRoute(rctx *Context, method methodTyp, path string) *node { return xn } + for endpoints := range xn.endpoints { + if endpoints == mALL || endpoints == mSTUB { + continue + } + rctx.methodsAllowed = append(rctx.methodsAllowed, endpoints) + } + // flag that the routing context found a route, but not a corresponding // supported method rctx.methodNotAllowed = true