Skip to content

Commit

Permalink
Add WithNotFoundHandler to ServeMux
Browse files Browse the repository at this point in the history
  • Loading branch information
roeldev committed Apr 25, 2024
1 parent 39cfb48 commit cea0b3a
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 6 deletions.
30 changes: 29 additions & 1 deletion router.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,10 @@ type serveMux = http.ServeMux
// implements the [Router] interface on top of that.
// See [http.ServeMux] for additional information about pattern syntax,
// compatibility etc.
type ServeMux struct{ *serveMux }
type ServeMux struct {
*serveMux
notFound http.Handler
}

// NewServeMux creates a new [ServeMux] and is ready to be used.
func NewServeMux() *ServeMux {
Expand Down Expand Up @@ -99,6 +102,31 @@ func (mux *ServeMux) HandleRoute(route Route) {
mux.serveMux.Handle(pattern, route)
}

// WithNotFoundHandler sets a [http.Handler] which is called when there is no
// matching pattern. If not set, [ServeMux] will use the internal
// [http.ServeMux]'s default not found handler, which is [http.NotFound].
func (mux *ServeMux) WithNotFoundHandler(h http.Handler) *ServeMux {
mux.notFound = h
return mux
}

func (mux *ServeMux) ServeHTTP(wri http.ResponseWriter, req *http.Request) {
if req.RequestURI == "*" {
if req.ProtoAtLeast(1, 1) {
wri.Header().Set("Connection", "close")
}
wri.WriteHeader(http.StatusBadRequest)
return
}

h, pattern := mux.serveMux.Handler(req)
if pattern != "" || mux.notFound == nil {
h.ServeHTTP(wri, req)
} else {
mux.notFound.ServeHTTP(wri, req)
}
}

func (mux *ServeMux) apply(srv *Server) error {
srv.Handler = mux
return nil
Expand Down
30 changes: 25 additions & 5 deletions router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,36 @@ func TestRoute_ServeHTTP(t *testing.T) {
}
}

func TestRouter_HandleRoute(t *testing.T) {
router := NewServeMux()
router.HandleRoute(Route{
func TestServeMux_HandleRoute(t *testing.T) {
mux := NewServeMux()
mux.HandleRoute(Route{
Pattern: "/",
Handler: http.HandlerFunc(func(res http.ResponseWriter, _q *http.Request) {
Handler: http.HandlerFunc(func(res http.ResponseWriter, _ *http.Request) {
res.WriteHeader(http.StatusOK)
}),
})

rec := httptest.NewRecorder()
router.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
mux.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
assert.Equal(t, http.StatusOK, rec.Code)
}

func TestServeMux_ServeHTTP(t *testing.T) {
t.Run("default not found", func(t *testing.T) {
rec := httptest.NewRecorder()
NewServeMux().ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
assert.Equal(t, http.StatusNotFound, rec.Code)
})

t.Run("custom not found", func(t *testing.T) {
const want = "my custom not found message"
mux := NewServeMux().
WithNotFoundHandler(http.HandlerFunc(func(wri http.ResponseWriter, _ *http.Request) {
_, _ = wri.Write([]byte(want))
}))

rec := httptest.NewRecorder()
mux.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
assert.Equal(t, want, rec.Body.String())
})
}

0 comments on commit cea0b3a

Please sign in to comment.