forked from go-goyave/goyave
-
Notifications
You must be signed in to change notification settings - Fork 0
/
middleware.go
240 lines (213 loc) · 7.05 KB
/
middleware.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
package goyave
import (
"net/http"
"strings"
"gorm.io/gorm"
"goyave.dev/goyave/v5/cors"
"goyave.dev/goyave/v5/util/errors"
"goyave.dev/goyave/v5/validation"
)
// Middleware are special handlers executed in a stack above the controller handler.
// They allow inspecting and filtering requests, transforming responses or providing
// additional information to the next handlers in the stack.
// Example uses are authentication, authorization, logging, panic recovery, CORS,
// validation, gzip compression.
//
// A middleware is also a `Composable`.
type Middleware interface {
Composable
// Handle receives the next handler of the stack and returns the handler
// that will be used in its place.
Handle(next Handler) Handler
}
// middlewareHolder stores an ordered list of middleware. It provides the means
// to wrap a handler with all of them and to retrieve a copy of the list.
type middlewareHolder struct {
// middleware is the ordered list of middleware. When applied to a handler,
// the first element becomes the outermost wrapper.
middleware []Middleware
}
// applyMiddleware wraps the given handler with every middleware of this holder
// and returns the resulting handler.
//
// The list is walked from its last element to its first one, so the first
// middleware of the list ends up being the outermost layer of the stack.
// If the holder has no middleware, the handler is returned as is.
func (h *middlewareHolder) applyMiddleware(handler Handler) Handler {
	wrapped := handler
	for remaining := len(h.middleware); remaining > 0; remaining-- {
		wrapped = h.middleware[remaining-1].Handle(wrapped)
	}
	return wrapped
}
// GetMiddleware returns a copy of the middleware applied on this holder.
// The returned slice is never nil and has its own backing array: replacing
// its elements doesn't alter the holder. The middleware values themselves
// are shared with the holder.
func (h *middlewareHolder) GetMiddleware() []Middleware {
	snapshot := make([]Middleware, len(h.middleware))
	copy(snapshot, h.middleware)
	return snapshot
}
// findMiddleware returns the first element of m that can be asserted to the
// type T. If there is none (or if m is empty), the zero value of T is returned.
func findMiddleware[T Middleware](m []Middleware) T {
	var found T
	for i := range m {
		if typed, ok := m[i].(T); ok {
			found = typed
			break
		}
	}
	return found
}
// hasMiddleware returns true if at least one element of m can be asserted
// to the type T. Returns false for an empty or nil list.
func hasMiddleware[T Middleware](m []Middleware) bool {
	for i := 0; i < len(m); i++ {
		_, matches := m[i].(T)
		if matches {
			return true
		}
	}
	return false
}
// routeHasMiddleware returns true if the given route's own middleware list
// contains a middleware of the T type.
// Only `route.middleware` is inspected: the middleware of the parent routers
// are not looked up by this function.
func routeHasMiddleware[T Middleware](route *Route) bool {
	own := route.middleware
	return hasMiddleware[T](own)
}
// routerHasMiddleware returns true if the given router or any of its
// parents has a middleware of the T type, either in its regular middleware
// or in its global middleware.
// The router must not be nil.
func routerHasMiddleware[T Middleware](router *Router) bool {
	// Walk up the router hierarchy, starting with the given router.
	for current := router; ; current = current.parent {
		if hasMiddleware[T](current.globalMiddleware.middleware) || hasMiddleware[T](current.middleware) {
			return true
		}
		if current.parent == nil {
			return false
		}
	}
}
// recoveryMiddleware is a middleware that recovers from panic and sends a 500 error code.
// The recovered error is logged, stored in the response, and the response status
// is forced to 500 Internal Server Error.
// If debugging is enabled in the config and the default status handler for the 500 status code
// has not been changed, the error is also written in the response.
type recoveryMiddleware struct {
Component
}
// Handle returns a handler that executes next and recovers from any panic
// occurring during its execution.
//
// When a panic is recovered, or when next didn't return normally, the error
// is wrapped with a stacktrace, logged, stored in the response and the
// response status is overridden with 500 Internal Server Error.
func (m *recoveryMiddleware) Handle(next Handler) Handler {
	return func(response *Response, request *Request) {
		// Stays false if next never returns normally. This allows detecting
		// abnormal terminations for which recover() yields nil.
		completed := false
		defer func() {
			// recover() must be called directly in this deferred function. Its depth
			// in the call stack also determines the skip count given to NewSkip.
			recovered := recover()
			if recovered != nil || !completed {
				// NOTE(review): recovered may be nil here (next didn't complete but
				// recover() returned nil) — confirm errors.NewSkip accepts a nil reason.
				wrapped := errors.NewSkip(recovered, 4).(*errors.Error) // Skipped: runtime.Callers, NewSkip, this func, runtime.panic
				m.Logger().Error(wrapped)
				response.err = wrapped
				response.status = http.StatusInternalServerError // Force status override
			}
		}()

		next(response, request)
		completed = true
	}
}
// languageMiddleware is a middleware that sets the language of a request
// (its `Lang` field).
//
// Uses the "Accept-Language" header to determine which language to use. If
// the header is not set or the language is not available, uses the default
// language as fallback. The detection itself is delegated to the
// `DetectLanguage` method of the value returned by `Lang()`.
//
// If "*" is provided, the default language will be used.
// If multiple languages are given, the first available language will be used,
// and if none are available, the default language will be used.
// If no variant is given (for example "en"), the first available variant will be used.
// For example, if "en-US" and "en-UK" are available and the request accepts "en",
// "en-US" will be used.
type languageMiddleware struct {
Component
}
// Handle returns a handler that sets `request.Lang` before calling next.
// If the "Accept-Language" header is empty or missing, the default language
// is used. Otherwise, the language is detected from the header's value.
func (m *languageMiddleware) Handle(next Handler) Handler {
	return func(response *Response, request *Request) {
		acceptLanguage := request.Header().Get("Accept-Language")
		if len(acceptLanguage) == 0 {
			request.Lang = m.Lang().GetDefault()
		} else {
			request.Lang = m.Lang().DetectLanguage(acceptLanguage)
		}

		next(response, request)
	}
}
// validateRequestMiddleware is a middleware that validates the request's
// query and body.
// If the validation rules are not met, sets the response status to
// 422 Unprocessable Entity and stops the handler chain. The `validation.Errors`
// returned by the validator are stored in the request's `Extra`
// (`ExtraValidationError` for the body, `ExtraQueryValidationError` for the query).
// This data can then be used in a status handler.
// If the validator itself returns errors, they are reported with `response.Error()`.
// This middleware requires the parse middleware.
type validateRequestMiddleware struct {
Component
// BodyRules generates the rules applied to the request's body (`Data`).
// Body validation is skipped if nil.
BodyRules RuleSetFunc
// QueryRules generates the rules applied to the request's query (`Query`).
// Query validation is skipped if nil.
QueryRules RuleSetFunc
}
// Handle returns a handler that validates the request's query against `QueryRules`,
// then the request's body against `BodyRules`, before calling next. A nil rule set
// function disables the corresponding validation.
//
// The rules in use are stored in the request's `Extra` (`ExtraQueryValidationRules`,
// `ExtraBodyValidationRules`), as well as the resulting `*validation.Errors`, if any
// (`ExtraQueryValidationError`, `ExtraValidationError`).
// After the body validation, `r.Data` is replaced with the validated data.
//
// The next handler is not called if:
//   - the validator returned errors: they are passed to `response.Error()`
//   - the query or the body didn't pass validation: the response status is set
//     to 422 Unprocessable Entity
func (m *validateRequestMiddleware) Handle(next Handler) Handler {
	return func(response *Response, r *Request) {
		extra := map[any]any{
			validation.ExtraRequest{}: r,
		}
		contentType := r.Header().Get("Content-Type")

		// The database is only given to the validator if a connection is configured.
		var db *gorm.DB
		if m.Config().GetString("database.connection") != "none" {
			db = m.DB().WithContext(r.Context())
		}

		var (
			bodyBag       *validation.Errors
			queryBag      *validation.Errors
			validatorErrs []error
		)

		if m.QueryRules != nil {
			queryOpt := &validation.Options{
				Data:                     r.Query,
				Rules:                    m.QueryRules(r).AsRules(),
				ConvertSingleValueArrays: true,
				Language:                 r.Lang,
				DB:                       db,
				Config:                   m.Config(),
				Logger:                   m.Logger(),
				Extra:                    extra,
			}
			r.Extra[ExtraQueryValidationRules{}] = queryOpt.Rules

			var failures []error
			queryBag, failures = validation.Validate(queryOpt)
			if queryBag != nil {
				r.Extra[ExtraQueryValidationError{}] = queryBag
			}
			// Appending an empty or nil slice is a no-op.
			validatorErrs = append(validatorErrs, failures...)
		}

		if m.BodyRules != nil {
			bodyOpt := &validation.Options{
				Data:  r.Data,
				Rules: m.BodyRules(r).AsRules(),
				// Single value arrays are not converted for JSON bodies.
				ConvertSingleValueArrays: !strings.HasPrefix(contentType, "application/json"),
				Language:                 r.Lang,
				DB:                       db,
				Config:                   m.Config(),
				Logger:                   m.Logger(),
				Extra:                    extra,
			}
			r.Extra[ExtraBodyValidationRules{}] = bodyOpt.Rules

			var failures []error
			bodyBag, failures = validation.Validate(bodyOpt)
			if bodyBag != nil {
				r.Extra[ExtraValidationError{}] = bodyBag
			}
			validatorErrs = append(validatorErrs, failures...)
			r.Data = bodyOpt.Data
		}

		switch {
		case len(validatorErrs) != 0:
			response.Error(validatorErrs)
		case bodyBag != nil || queryBag != nil:
			response.Status(http.StatusUnprocessableEntity)
		default:
			next(response, r)
		}
	}
}
// corsMiddleware is a middleware that applies the `*cors.Options` found in the
// matched route's `MetaCORS` meta to the response headers, and handles
// preflight (`OPTIONS`) requests.
// Requests are passed through untouched if no options are found.
type corsMiddleware struct {
Component
}
// Handle returns a handler that applies the CORS options registered in the
// route's `MetaCORS` meta.
//
// If the meta is missing, nil, or a nil `*cors.Options`, the request is passed
// to next untouched.
// An `OPTIONS` request without the "Access-Control-Request-Method" header is
// rejected with the status 400 Bad Request.
// For other requests, the common CORS headers are configured, then:
//   - non-`OPTIONS` requests are passed to next
//   - `OPTIONS` requests are handled as preflight: they are passed to next only
//     if `OptionsPassthrough` is enabled, otherwise 204 No Content is written.
func (m *corsMiddleware) Handle(next Handler) Handler {
	return func(response *Response, request *Request) {
		meta, found := request.Route.LookupMeta(MetaCORS)
		if !found || meta == nil {
			next(response, request)
			return
		}

		// Panics if the meta holds anything other than a *cors.Options.
		options := meta.(*cors.Options)
		if options == nil {
			next(response, request)
			return
		}

		headers := response.Header()
		requestHeaders := request.Header()

		if request.Method() == http.MethodOptions && requestHeaders.Get("Access-Control-Request-Method") == "" {
			response.Status(http.StatusBadRequest)
			return
		}

		options.ConfigureCommon(headers, requestHeaders)

		if request.Method() != http.MethodOptions {
			next(response, request)
			return
		}

		options.HandlePreflight(headers, requestHeaders)
		if options.OptionsPassthrough {
			next(response, request)
			return
		}
		response.WriteHeader(http.StatusNoContent)
	}
}