Skip to content

Commit edb1b3f

Browse files
Merge pull request #72 from Workiva/cors_whitelist
Require origin whitelist for CORS middleware
2 parents 07477d1 + b4b3567 commit edb1b3f

File tree

5 files changed

+139
-34
lines changed

5 files changed

+139
-34
lines changed

rest/api.go

+13-4
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,17 @@ func NewConfiguration() *Configuration {
8989
}
9090
}
9191

92+
// MiddlewareError is returned by Middleware to indicate that a request should
93+
// not be served.
94+
type MiddlewareError struct {
95+
Code int
96+
Response []byte
97+
}
98+
9299
// Middleware can be passed in to API#Start and API#StartTLS and will be
93-
// invoked on every request to a route handled by the API. Returns true if the
94-
// request should be terminated, false if it should continue.
95-
type Middleware func(w http.ResponseWriter, r *http.Request) bool
100+
// invoked on every request to a route handled by the API. Returns a
101+
// MiddlewareError if the request should be terminated.
102+
type Middleware func(w http.ResponseWriter, r *http.Request) *MiddlewareError
96103

97104
// middlewareProxy proxies an http.Handler by invoking middleware before
98105
// passing the request to the Handler. It implements the http.Handler
@@ -106,7 +113,9 @@ type middlewareProxy struct {
106113
// proxied http.Handler.
107114
func (m *middlewareProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
108115
for _, middleware := range m.middleware {
109-
if middleware(w, r) {
116+
if err := middleware(w, r); err != nil {
117+
w.WriteHeader(err.Code)
118+
w.Write(err.Response)
110119
return
111120
}
112121
}

rest/api_test.go

+6-6
Original file line numberDiff line numberDiff line change
@@ -1082,14 +1082,14 @@ func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
10821082
}
10831083

10841084
// Ensures that middlewareProxy invokes middleware and doesn't delegate to the
1085-
// wrapped http.Handler if any middleware return true.
1085+
// wrapped http.Handler if any middleware return error.
10861086
func TestMiddlewareProxyTerminate(t *testing.T) {
10871087
assert := assert.New(t)
10881088
handler := &httpHandler{}
10891089
called := false
1090-
middleware := func(w http.ResponseWriter, r *http.Request) bool {
1090+
middleware := func(w http.ResponseWriter, r *http.Request) *MiddlewareError {
10911091
called = true
1092-
return true
1092+
return &MiddlewareError{}
10931093
}
10941094
req, _ := http.NewRequest("GET", "http://example.com/foo", nil)
10951095
w := httptest.NewRecorder()
@@ -1102,14 +1102,14 @@ func TestMiddlewareProxyTerminate(t *testing.T) {
11021102
}
11031103

11041104
// Ensures that middlewareProxy invokes middleware and delegates to the wrapped
1105-
// http.Handler if all middleware return false.
1105+
// http.Handler if all middleware return nil.
11061106
func TestMiddlewareProxyDelegate(t *testing.T) {
11071107
assert := assert.New(t)
11081108
handler := &httpHandler{}
11091109
called := false
1110-
middleware := func(w http.ResponseWriter, r *http.Request) bool {
1110+
middleware := func(w http.ResponseWriter, r *http.Request) *MiddlewareError {
11111111
called = true
1112-
return false
1112+
return nil
11131113
}
11141114
req, _ := http.NewRequest("GET", "http://example.com/foo", nil)
11151115
w := httptest.NewRecorder()

rest/example_middleware_test.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,11 @@ func HandlerMiddleware(next http.Handler) http.Handler {
7878
}
7979

8080
// Global API middleware is implemented as a function which takes an
81-
// http.ResponseWriter and http.Request and returns a bool indicating if the
82-
// request should terminate or not.
83-
func GlobalMiddleware(w http.ResponseWriter, r *http.Request) bool {
81+
// http.ResponseWriter and http.Request and returns a MiddlewareError if the
82+
// request should terminate.
83+
func GlobalMiddleware(w http.ResponseWriter, r *http.Request) *MiddlewareError {
8484
log.Println(r)
85-
return false
85+
return nil
8686
}
8787

8888
// This example shows how to implement request middleware. ResourceHandlers

rest/middleware/cors.go

+74-11
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,79 @@
11
package middleware
22

3-
import "net/http"
4-
5-
// CORSMiddleware enables cross-origin requests. It implements the Middleware
6-
// interface.
7-
func CORSMiddleware(w http.ResponseWriter, r *http.Request) bool {
8-
if origin := r.Header.Get("Origin"); origin != "" {
9-
w.Header().Set("Access-Control-Allow-Origin", origin)
10-
w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
11-
w.Header()["Access-Control-Allow-Headers"] = r.Header["Access-Control-Request-Headers"]
12-
w.Header().Set("Access-Control-Allow-Credentials", "true")
3+
import (
4+
"net/http"
5+
"net/url"
6+
"strings"
7+
8+
"github.com/Workiva/go-rest/rest"
9+
)
10+
11+
// NewCORSMiddleware returns a Middleware which enables cross-origin requests.
12+
// Origin must match the supplied whitelist (which supports wildcards). Returns
13+
// a MiddlewareError if the request should be terminated.
14+
func NewCORSMiddleware(originWhitelist []string) rest.Middleware {
15+
return func(w http.ResponseWriter, r *http.Request) *rest.MiddlewareError {
16+
originMatch := false
17+
if origin := r.Header.Get("Origin"); checkOrigin(origin, originWhitelist) {
18+
w.Header().Set("Access-Control-Allow-Origin", origin)
19+
w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
20+
w.Header()["Access-Control-Allow-Headers"] = r.Header["Access-Control-Request-Headers"]
21+
w.Header().Set("Access-Control-Allow-Credentials", "true")
22+
originMatch = true
23+
}
24+
25+
var err *rest.MiddlewareError
26+
if r.Method == "OPTIONS" {
27+
err = &rest.MiddlewareError{Code: http.StatusOK}
28+
} else if !originMatch {
29+
err = &rest.MiddlewareError{
30+
Code: http.StatusBadRequest,
31+
Response: []byte("Origin does not match whitelist"),
32+
}
33+
}
34+
return err
35+
}
36+
}
37+
38+
// checkOrigin checks if the given origin is contained in the origin whitelist.
39+
// Returns true if the origin is in the whitelist, false if not.
40+
func checkOrigin(origin string, whitelist []string) bool {
41+
url, err := url.Parse(origin)
42+
if err != nil {
43+
return false
44+
}
45+
originComponents := strings.Split(url.Host, ".")
46+
47+
checkWhitelist:
48+
for _, whitelisted := range whitelist {
49+
if whitelisted == "*" {
50+
return true
51+
}
52+
53+
whitelistedComponents := strings.Split(whitelisted, ".")
54+
55+
if len(originComponents) != len(whitelistedComponents) {
56+
// Do not match, try next host in whitelist.
57+
continue
58+
}
59+
60+
for i, originComponent := range originComponents {
61+
whitelistedComponent := whitelistedComponents[i]
62+
if whitelistedComponent == "*" {
63+
// Wildcard, check next component.
64+
continue
65+
}
66+
67+
if originComponent != whitelistedComponent {
68+
// Mismatch, try next host in whitelist.
69+
continue checkWhitelist
70+
}
71+
}
72+
73+
// Origin matches whitelisted domain.
74+
return true
1375
}
1476

15-
return r.Method == "OPTIONS"
77+
// No matches.
78+
return false
1679
}

rest/middleware/cors_test.go

+42-9
Original file line numberDiff line numberDiff line change
@@ -8,26 +8,59 @@ import (
88
"github.com/stretchr/testify/assert"
99
)
1010

11-
// Ensures that CORSMiddleware applies the headers needed for CORS and returns
12-
// true for non-OPTIONS requests.
13-
func TestCORSMiddleware(t *testing.T) {
11+
// Ensures that CORSMiddleware applies the headers needed for CORS when * is
12+
// present in the whitelist.
13+
func TestCORSMiddlewareAll(t *testing.T) {
1414
assert := assert.New(t)
1515
req, _ := http.NewRequest("GET", "http://example.com/foo", nil)
16-
req.Header.Set("Origin", "abc")
16+
req.Header.Set("Origin", "http://foo.com")
1717
req.Header.Set("Access-Control-Request-Headers", "def")
1818
w := httptest.NewRecorder()
19-
assert.False(CORSMiddleware(w, req))
19+
assert.Nil(NewCORSMiddleware([]string{"*"})(w, req))
2020

21-
assert.Equal("abc", w.Header().Get("Access-Control-Allow-Origin"))
21+
assert.Equal("http://foo.com", w.Header().Get("Access-Control-Allow-Origin"))
2222
assert.Equal("POST, GET, OPTIONS, PUT, DELETE", w.Header().Get("Access-Control-Allow-Methods"))
2323
assert.Equal([]string{"def"}, w.Header()["Access-Control-Allow-Headers"])
2424
assert.Equal([]string{"true"}, w.Header()["Access-Control-Allow-Credentials"])
2525
}
2626

27-
// Ensures that CORSMiddleware returns true for OPTIONS requests.
27+
// Ensures that CORSMiddleware applies the headers needed for CORS and respects
28+
// the origin whitelist.
29+
func TestCORSMiddlewareWhitelist(t *testing.T) {
30+
assert := assert.New(t)
31+
req, _ := http.NewRequest("GET", "http://example.com/foo", nil)
32+
req.Header.Set("Origin", "http://foo.wdesk.com")
33+
req.Header.Set("Access-Control-Request-Headers", "def")
34+
w := httptest.NewRecorder()
35+
middleware := NewCORSMiddleware([]string{"blah.wdesk.org", "*.wdesk.com"})
36+
assert.Nil(middleware(w, req))
37+
38+
assert.Equal("http://foo.wdesk.com", w.Header().Get("Access-Control-Allow-Origin"))
39+
assert.Equal("POST, GET, OPTIONS, PUT, DELETE", w.Header().Get("Access-Control-Allow-Methods"))
40+
assert.Equal([]string{"def"}, w.Header()["Access-Control-Allow-Headers"])
41+
assert.Equal([]string{"true"}, w.Header()["Access-Control-Allow-Credentials"])
42+
43+
// Mismatched origin
44+
req, _ = http.NewRequest("GET", "http://example.com/foo", nil)
45+
req.Header.Set("Origin", "http://baz.wdesk.org")
46+
req.Header.Set("Access-Control-Request-Headers", "def")
47+
w = httptest.NewRecorder()
48+
err := middleware(w, req)
49+
assert.NotNil(err)
50+
assert.Equal(http.StatusBadRequest, err.Code)
51+
52+
assert.Equal("", w.Header().Get("Access-Control-Allow-Origin"))
53+
assert.Equal("", w.Header().Get("Access-Control-Allow-Methods"))
54+
assert.Nil(w.Header()["Access-Control-Allow-Headers"])
55+
assert.Nil(w.Header()["Access-Control-Allow-Credentials"])
56+
}
57+
58+
// Ensures that CORSMiddleware returns a MiddlewareError with a 200 response
59+
// code.
2860
func TestCORSMiddlewareOptionsRequest(t *testing.T) {
2961
req, _ := http.NewRequest("OPTIONS", "http://example.com/foo", nil)
30-
req.Header.Set("Origin", "abc")
62+
req.Header.Set("Origin", "http://foo.com")
3163
w := httptest.NewRecorder()
32-
assert.True(t, CORSMiddleware(w, req))
64+
err := NewCORSMiddleware([]string{"foo.com"})(w, req)
65+
assert.Equal(t, http.StatusOK, err.Code)
3366
}

0 commit comments

Comments
 (0)