Skip to content
This repository was archived by the owner on Sep 21, 2022. It is now read-only.

Commit 132dc5c

Browse files
committed
Fixes #50
1 parent 4ed5c7f commit 132dc5c

File tree

2 files changed

+60
-36
lines changed

2 files changed

+60
-36
lines changed

middleware_test.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ func TestMiddlewarePlain(t *testing.T) {
6767
}
6868

6969
func TestMiddlewareContext(t *testing.T) {
70-
t.Skip()
7170
expect := "3"
7271
r := NewRouter()
7372
_ = r.Add(&Sample{}, contextMiddleware(0), contextMiddleware(1), contextMiddleware(2))
@@ -88,7 +87,6 @@ func TestMiddlewareContext(t *testing.T) {
8887
}
8988

9089
func TestMiddlewareMixed(t *testing.T) {
91-
t.Skip()
9290
expect := "6"
9391

9492
r := NewRouter()

routes.go

Lines changed: 60 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -279,61 +279,89 @@ func splitRoutes(routeStr string) (*route, error) {
279279
return nil, ErrRouteStringFormat
280280
}
281281

282+
type middlewareTyp int
283+
284+
const (
285+
plainMiddleware middlewareTyp = iota
286+
ctxMiddleware
287+
)
288+
289+
type middleware struct {
290+
typ middlewareTyp
291+
value interface{}
292+
}
293+
294+
func (m *middleware) ToHandler(ctx *Context) func(http.Handler) http.Handler {
295+
if m.typ == plainMiddleware {
296+
return m.value.(func(http.Handler) http.Handler)
297+
}
298+
fn := m.value.(func(*Context) error)
299+
return func(h http.Handler) http.Handler {
300+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
301+
err := fn(ctx)
302+
if err != nil {
303+
return
304+
}
305+
h.ServeHTTP(w, r)
306+
})
307+
}
308+
}
309+
282310
// add registers controller ctrl, using activeRoute. If middlewares are provided, utron uses
283311
// alice package to chain middlewares.
284312
func (r *Router) add(activeRoute *route, ctrl Controller, middlewares ...interface{}) error {
285-
chain := alice.New() // alice on chains
313+
var m []*middleware
286314
if len(middlewares) > 0 {
287-
var m []alice.Constructor
288315
for _, v := range middlewares {
289316
switch v.(type) {
290317
case func(http.Handler) http.Handler:
291-
m = append(m, v.(func(http.Handler) http.Handler))
318+
m = append(m, &middleware{
319+
typ: plainMiddleware,
320+
value: v,
321+
})
292322
case func(*Context) error:
293-
294-
// wrap func(*Context)error to a func(http.Handler)http.Handler
295-
//
296-
//TODO put this into a separate function?
297-
ctxMiddleware := func(h http.Handler) http.Handler {
298-
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
299-
ctx := NewContext(w, req)
300-
r.prepareContext(ctx)
301-
err := v.(func(*Context) error)(ctx)
302-
if err != nil {
303-
cerr := ctx.Commit()
304-
if cerr != nil {
305-
logThis.Errors(req.URL.Path, " ", cerr.Error())
306-
}
307-
return
308-
}
309-
h.ServeHTTP(ctx.Response(), ctx.Request())
310-
})
311-
}
312-
313-
m = append(m, ctxMiddleware)
323+
m = append(m, &middleware{
324+
typ: ctxMiddleware,
325+
value: v,
326+
})
314327

315328
default:
316329
return fmt.Errorf("unsupported middleware %v", v)
317330
}
318331
}
319-
chain = alice.New(m...)
320332
}
321-
322333
// register methods if any
323334
if len(activeRoute.methods) > 0 {
324335
r.HandleFunc(activeRoute.pattern, func(w http.ResponseWriter, req *http.Request) {
325-
chain.ThenFunc(r.wrapController(activeRoute.fn, ctrl)).ServeHTTP(w, req)
336+
ctx := NewContext(w, req)
337+
r.prepareContext(ctx)
338+
chain := chainMiddleware(ctx, m...)
339+
chain.ThenFunc(r.wrapController(ctx, activeRoute.fn, ctrl)).ServeHTTP(w, req)
326340
}).Methods(activeRoute.methods...)
327341
return nil
328342
}
329-
330343
r.HandleFunc(activeRoute.pattern, func(w http.ResponseWriter, req *http.Request) {
331-
chain.ThenFunc(r.wrapController(activeRoute.fn, ctrl)).ServeHTTP(w, req)
344+
ctx := NewContext(w, req)
345+
r.prepareContext(ctx)
346+
chain := chainMiddleware(ctx, m...)
347+
chain.ThenFunc(r.wrapController(ctx, activeRoute.fn, ctrl)).ServeHTTP(w, req)
332348
})
333349

334350
return nil
335351
}
336352

353+
func chainMiddleware(ctx *Context, wares ...*middleware) alice.Chain {
354+
if len(wares) > 0 {
355+
var m []alice.Constructor
356+
for _, v := range wares {
357+
m = append(m, v.ToHandler(ctx))
358+
}
359+
return alice.New(m...)
360+
}
361+
return alice.New()
362+
363+
}
364+
337365
// prepareContext sets view,config and model on the ctx.
338366
func (r *Router) prepareContext(ctx *Context) {
339367
if r.app != nil {
@@ -350,9 +378,7 @@ func (r *Router) prepareContext(ctx *Context) {
350378
}
351379

352380
// executes the method fn on Controller ctrl, it sets context.
353-
func (r *Router) handleController(w http.ResponseWriter, req *http.Request, fn string, ctrl Controller) {
354-
ctx := NewContext(w, req)
355-
r.prepareContext(ctx)
381+
func (r *Router) handleController(ctx *Context, fn string, ctrl Controller) {
356382
ctrl.New(ctx)
357383

358384
// execute the method
@@ -371,9 +397,9 @@ func (r *Router) handleController(w http.ResponseWriter, req *http.Request, fn s
371397
}
372398

373399
// wrapController wraps a controller ctrl with method fn, and returns http.HandleFunc
374-
func (r *Router) wrapController(fn string, ctrl Controller) func(http.ResponseWriter, *http.Request) {
400+
func (r *Router) wrapController(ctx *Context, fn string, ctrl Controller) func(http.ResponseWriter, *http.Request) {
375401
return func(w http.ResponseWriter, req *http.Request) {
376-
r.handleController(w, req, fn, ctrl)
402+
r.handleController(ctx, fn, ctrl)
377403
}
378404
}
379405

0 commit comments

Comments
 (0)