Skip to content

Commit c06069c

Browse files
chore (refactoring): revisit middleware types
1 parent 5bdb3f8 commit c06069c

File tree

5 files changed

+51
-49
lines changed

5 files changed

+51
-49
lines changed

server/ctrl/webdav.go

+8-6
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
package ctrl
22

33
import (
4-
. "github.com/mickael-kerjean/filestash/server/common"
5-
"github.com/mickael-kerjean/filestash/server/model"
6-
"github.com/mickael-kerjean/net/webdav"
74
"net/http"
85
"path/filepath"
96
"strings"
7+
8+
. "github.com/mickael-kerjean/filestash/server/common"
9+
"github.com/mickael-kerjean/filestash/server/middleware"
10+
"github.com/mickael-kerjean/filestash/server/model"
11+
"github.com/mickael-kerjean/net/webdav"
1012
)
1113

1214
func WebdavHandler(ctx *App, res http.ResponseWriter, req *http.Request) {
@@ -53,8 +55,8 @@ func WebdavHandler(ctx *App, res http.ResponseWriter, req *http.Request) {
5355
* an imbecile and considering we can't even see the source code they are running, the best approach we
5456
* could go on is: "crap in, crap out" where useless request coming in are identified and answer appropriatly
5557
*/
56-
func WebdavBlacklist(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
57-
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
58+
func WebdavBlacklist(fn middleware.HandlerFunc) middleware.HandlerFunc {
59+
return middleware.HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
5860
base := filepath.Base(req.URL.String())
5961

6062
if req.Method == "PUT" || req.Method == "MKCOL" {
@@ -125,5 +127,5 @@ func WebdavBlacklist(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx
125127
}
126128
}
127129
fn(ctx, res, req)
128-
}
130+
})
129131
}

server/middleware/context.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import (
99
"strings"
1010
)
1111

12-
func BodyParser(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
12+
func BodyParser(fn HandlerFunc) HandlerFunc {
1313
extractBody := func(req *http.Request) (map[string]interface{}, error) {
1414
body := map[string]interface{}{}
1515
byt, err := ioutil.ReadAll(req.Body)
@@ -25,14 +25,14 @@ func BodyParser(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App
2525
return body, nil
2626
}
2727

28-
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
28+
return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
2929
var err error
3030
if ctx.Body, err = extractBody(req); err != nil {
3131
SendErrorResult(res, ErrNotValid)
3232
return
3333
}
3434
fn(ctx, res, req)
35-
}
35+
})
3636
}
3737

3838
func GenerateRequestID(prefix string) string {

server/middleware/http.go

+21-21
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ import (
1010
"strings"
1111
)
1212

13-
func ApiHeaders(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
14-
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
13+
func ApiHeaders(fn HandlerFunc) HandlerFunc {
14+
return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
1515
header := res.Header()
1616
header.Set("Content-Type", "application/json")
1717
header.Set("Cache-Control", "no-cache")
@@ -20,20 +20,20 @@ func ApiHeaders(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App
2020
header.Set("X-Request-ID", GenerateRequestID("API"))
2121
}
2222
fn(ctx, res, req)
23-
}
23+
})
2424
}
2525

26-
func StaticHeaders(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
27-
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
26+
func StaticHeaders(fn HandlerFunc) HandlerFunc {
27+
return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
2828
header := res.Header()
2929
header.Set("Content-Type", GetMimeType(filepath.Ext(req.URL.Path)))
3030
header.Set("Cache-Control", "max-age=2592000")
3131
fn(ctx, res, req)
32-
}
32+
})
3333
}
3434

35-
func IndexHeaders(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
36-
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
35+
func IndexHeaders(fn HandlerFunc) HandlerFunc {
36+
return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
3737
header := res.Header()
3838
header.Set("Content-Type", "text/html")
3939
header.Set("Cache-Control", "no-cache")
@@ -65,23 +65,23 @@ func IndexHeaders(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *A
6565
}
6666
// header.Set("Content-Security-Policy", cspHeader)
6767
fn(ctx, res, req)
68-
}
68+
})
6969
}
7070

71-
func SecureHeaders(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
72-
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
71+
func SecureHeaders(fn HandlerFunc) HandlerFunc {
72+
return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
7373
header := res.Header()
7474
if Config.Get("general.force_ssl").Bool() {
7575
header.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload")
7676
}
7777
header.Set("X-Content-Type-Options", "nosniff")
7878
header.Set("X-XSS-Protection", "1; mode=block")
7979
fn(ctx, res, req)
80-
}
80+
})
8181
}
8282

83-
func SecureOrigin(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
84-
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
83+
func SecureOrigin(fn HandlerFunc) HandlerFunc {
84+
return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
8585
if host := Config.Get("general.host").String(); host != "" {
8686
host = strings.TrimPrefix(host, "http://")
8787
host = strings.TrimPrefix(host, "https://")
@@ -105,11 +105,11 @@ func SecureOrigin(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *A
105105

106106
Log.Warning("Intrusion detection: %s - %s", RetrievePublicIp(req), req.URL.String())
107107
SendErrorResult(res, ErrNotAllowed)
108-
}
108+
})
109109
}
110110

111-
func WithPublicAPI(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
112-
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
111+
func WithPublicAPI(fn HandlerFunc) HandlerFunc {
112+
return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
113113
apiKey := req.URL.Query().Get("key")
114114
if apiKey == "" {
115115
fn(ctx, res, req)
@@ -132,13 +132,13 @@ func WithPublicAPI(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *
132132
return
133133
}
134134
fn(ctx, res, req)
135-
}
135+
})
136136
}
137137

138138
var limiter = rate.NewLimiter(10, 1000)
139139

140-
func RateLimiter(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
141-
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
140+
func RateLimiter(fn HandlerFunc) HandlerFunc {
141+
return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
142142
if limiter.Allow() == false {
143143
Log.Warning("middleware::http::ratelimit too many requests")
144144
SendErrorResult(
@@ -148,7 +148,7 @@ func RateLimiter(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *Ap
148148
return
149149
}
150150
fn(ctx, res, req)
151-
}
151+
})
152152
}
153153

154154
func EnableCors(req *http.Request, res http.ResponseWriter, host string) error {

server/middleware/index.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import (
77
)
88

99
type HandlerFunc func(*App, http.ResponseWriter, *http.Request)
10-
type Middleware func(func(*App, http.ResponseWriter, *http.Request)) func(*App, http.ResponseWriter, *http.Request)
10+
type Middleware func(HandlerFunc) HandlerFunc
1111

1212
func init() {
1313
Hooks.Register.Onload(func() {

server/middleware/session.go

+18-18
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,18 @@ import (
1414
"time"
1515
)
1616

17-
func LoggedInOnly(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
18-
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
17+
func LoggedInOnly(fn HandlerFunc) HandlerFunc {
18+
return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
1919
if ctx.Backend == nil || ctx.Session == nil {
2020
SendErrorResult(res, ErrPermissionDenied)
2121
return
2222
}
2323
fn(ctx, res, req)
24-
}
24+
})
2525
}
2626

27-
func AdminOnly(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
28-
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
27+
func AdminOnly(fn HandlerFunc) HandlerFunc {
28+
return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
2929
if admin := Config.Get("auth.admin").String(); admin != "" {
3030
c, err := req.Cookie(COOKIE_NAME_ADMIN)
3131
if err != nil {
@@ -47,11 +47,11 @@ func AdminOnly(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App,
4747
}
4848
}
4949
fn(ctx, res, req)
50-
}
50+
})
5151
}
5252

53-
func SessionStart(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
54-
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
53+
func SessionStart(fn HandlerFunc) HandlerFunc {
54+
return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
5555
var err error
5656

5757
if ctx.Share, err = _extractShare(req); err != nil {
@@ -72,21 +72,21 @@ func SessionStart(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *A
7272
return
7373
}
7474
fn(ctx, res, req)
75-
}
75+
})
7676
}
7777

78-
func SessionTry(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
79-
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
78+
func SessionTry(fn HandlerFunc) HandlerFunc {
79+
return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
8080
ctx.Share, _ = _extractShare(req)
8181
ctx.Authorization = _extractAuthorization(req)
8282
ctx.Session, _ = _extractSession(req, ctx)
8383
ctx.Backend, _ = _extractBackend(req, ctx)
8484
fn(ctx, res, req)
85-
}
85+
})
8686
}
8787

88-
func RedirectSharedLoginIfNeeded(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
89-
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
88+
func RedirectSharedLoginIfNeeded(fn HandlerFunc) HandlerFunc {
89+
return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
9090
share_id := _extractShareId(req)
9191
if share_id == "" {
9292
if mux.Vars(req)["share"] == "private" {
@@ -103,11 +103,11 @@ func RedirectSharedLoginIfNeeded(fn func(*App, http.ResponseWriter, *http.Reques
103103
return
104104
}
105105
fn(ctx, res, req)
106-
}
106+
})
107107
}
108108

109-
func CanManageShare(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx *App, res http.ResponseWriter, req *http.Request) {
110-
return func(ctx *App, res http.ResponseWriter, req *http.Request) {
109+
func CanManageShare(fn HandlerFunc) HandlerFunc {
110+
return HandlerFunc(func(ctx *App, res http.ResponseWriter, req *http.Request) {
111111
share_id := mux.Vars(req)["share"]
112112
if share_id == "" {
113113
Log.Debug("middleware::session::share 'invalid share id'")
@@ -167,7 +167,7 @@ func CanManageShare(fn func(*App, http.ResponseWriter, *http.Request)) func(ctx
167167
}
168168
SendErrorResult(res, ErrPermissionDenied)
169169
return
170-
}
170+
})
171171
}
172172

173173
func _extractAuthorization(req *http.Request) (token string) {

0 commit comments

Comments
 (0)