From 882c15e6462ce0630529d04c330e84942ebd059d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Goran=20Mari=C4=87?= <45515666+GocaMaric@users.noreply.github.com> Date: Wed, 18 Sep 2024 22:06:37 +0200 Subject: [PATCH] Update content_type.go (#880) --- middleware/content_type.go | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/middleware/content_type.go b/middleware/content_type.go index 023978fa..e61ff264 100644 --- a/middleware/content_type.go +++ b/middleware/content_type.go @@ -6,36 +6,32 @@ import ( ) // SetHeader is a convenience handler to set a response header key/value -func SetHeader(key, value string) func(next http.Handler) http.Handler { +func SetHeader(key, value string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { - fn := func(w http.ResponseWriter, r *http.Request) { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set(key, value) next.ServeHTTP(w, r) - } - return http.HandlerFunc(fn) + }) } } // AllowContentType enforces a whitelist of request Content-Types otherwise responds // with a 415 Unsupported Media Type status. -func AllowContentType(contentTypes ...string) func(next http.Handler) http.Handler { +func AllowContentType(contentTypes ...string) func(http.Handler) http.Handler { allowedContentTypes := make(map[string]struct{}, len(contentTypes)) for _, ctype := range contentTypes { allowedContentTypes[strings.TrimSpace(strings.ToLower(ctype))] = struct{}{} } return func(next http.Handler) http.Handler { - fn := func(w http.ResponseWriter, r *http.Request) { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.ContentLength == 0 { - // skip check for empty content body + // Skip check for empty content body next.ServeHTTP(w, r) return } - s := strings.ToLower(strings.TrimSpace(r.Header.Get("Content-Type"))) - if i := strings.Index(s, ";"); i > -1 { - s = s[0:i] - } + s := strings.ToLower(strings.TrimSpace(strings.Split(r.Header.Get("Content-Type"), ";")[0])) if _, ok := allowedContentTypes[s]; ok { next.ServeHTTP(w, r) @@ -43,7 +39,7 @@ func AllowContentType(contentTypes ...string) func(next http.Handler) http.Handl } w.WriteHeader(http.StatusUnsupportedMediaType) - } - return http.HandlerFunc(fn) + }) } } +