@@ -31,6 +31,37 @@ func init() {
31
31
* Tests
32
32
*/
33
33
34
+ func TestServerRootHandler (t * testing.T ) {
35
+ assert := assert .New (t )
36
+ config = newDefaultConfig ()
37
+
38
+ // X-Forwarded headers should be read into request
39
+ req := httptest .NewRequest ("POST" , "http://should-use-x-forwarded.com/should?ignore=me" , nil )
40
+ req .Header .Add ("X-Forwarded-Method" , "GET" )
41
+ req .Header .Add ("X-Forwarded-Proto" , "https" )
42
+ req .Header .Add ("X-Forwarded-Host" , "example.com" )
43
+ req .Header .Add ("X-Forwarded-Uri" , "/foo?q=bar" )
44
+ NewServer ().RootHandler (httptest .NewRecorder (), req )
45
+
46
+ assert .Equal ("GET" , req .Method , "x-forwarded-method should be read into request" )
47
+ assert .Equal ("example.com" , req .Host , "x-forwarded-host should be read into request" )
48
+ assert .Equal ("/foo" , req .URL .Path , "x-forwarded-uri should be read into request" )
49
+ assert .Equal ("/foo?q=bar" , req .URL .RequestURI (), "x-forwarded-uri should be read into request" )
50
+
51
+ // Other X-Forwarded headers should be read in into request and original URL
52
+ // should be preserved if X-Forwarded-Uri not present
53
+ req = httptest .NewRequest ("POST" , "http://should-use-x-forwarded.com/should-not?ignore=me" , nil )
54
+ req .Header .Add ("X-Forwarded-Method" , "GET" )
55
+ req .Header .Add ("X-Forwarded-Proto" , "https" )
56
+ req .Header .Add ("X-Forwarded-Host" , "example.com" )
57
+ NewServer ().RootHandler (httptest .NewRecorder (), req )
58
+
59
+ assert .Equal ("GET" , req .Method , "x-forwarded-method should be read into request" )
60
+ assert .Equal ("example.com" , req .Host , "x-forwarded-host should be read into request" )
61
+ assert .Equal ("/should-not" , req .URL .Path , "request url should be preserved if x-forwarded-uri not present" )
62
+ assert .Equal ("/should-not?ignore=me" , req .URL .RequestURI (), "request url should be preserved if x-forwarded-uri not present" )
63
+ }
64
+
34
65
func TestServerAuthHandlerInvalid (t * testing.T ) {
35
66
assert := assert .New (t )
36
67
config = newDefaultConfig ()
@@ -90,10 +121,10 @@ func TestServerAuthHandlerExpired(t *testing.T) {
90
121
config .Domains = []string {"test.com" }
91
122
92
123
// Should redirect expired cookie
93
- req := newDefaultHttpRequest ( " /foo" )
124
+ req := newHTTPRequest ( "GET" , "http://example.com /foo" )
94
125
c := MakeCookie (
req ,
"[email protected] " )
95
126
res , _ := doHttpRequest (req , c )
96
- assert .Equal (307 , res .StatusCode , "request with expired cookie should be redirected" )
127
+ require .Equal (t , 307 , res .StatusCode , "request with expired cookie should be redirected" )
97
128
98
129
// Check for CSRF cookie
99
130
var cookie * http.Cookie
@@ -116,7 +147,7 @@ func TestServerAuthHandlerValid(t *testing.T) {
116
147
config = newDefaultConfig ()
117
148
118
149
// Should allow valid request email
119
- req := newDefaultHttpRequest ( " /foo" )
150
+ req := newHTTPRequest ( "GET" , "http://example.com /foo" )
120
151
c := MakeCookie (
req ,
"[email protected] " )
121
152
config .Domains = []string {}
122
153
@@ -131,6 +162,7 @@ func TestServerAuthHandlerValid(t *testing.T) {
131
162
132
163
func TestServerAuthCallback (t * testing.T ) {
133
164
assert := assert .New (t )
165
+ require := require .New (t )
134
166
config = newDefaultConfig ()
135
167
136
168
// Setup OAuth server
@@ -148,27 +180,28 @@ func TestServerAuthCallback(t *testing.T) {
148
180
}
149
181
150
182
// Should pass auth response request to callback
151
- req := newDefaultHttpRequest ( " /_oauth" )
183
+ req := newHTTPRequest ( "GET" , "http://example.com /_oauth" )
152
184
res , _ := doHttpRequest (req , nil )
153
185
assert .Equal (401 , res .StatusCode , "auth callback without cookie shouldn't be authorised" )
154
186
155
187
// Should catch invalid csrf cookie
156
- req = newDefaultHttpRequest ("/_oauth?state=12345678901234567890123456789012:http://redirect" )
188
+ nonce := "12345678901234567890123456789012"
189
+ req = newHTTPRequest ("GET" , "http://example.com/_oauth?state=" + nonce + ":http://redirect" )
157
190
c := MakeCSRFCookie (req , "nononononononononononononononono" )
158
191
res , _ = doHttpRequest (req , c )
159
192
assert .Equal (401 , res .StatusCode , "auth callback with invalid cookie shouldn't be authorised" )
160
193
161
194
// Should catch invalid provider cookie
162
- req = newDefaultHttpRequest ( "/ _oauth?state=12345678901234567890123456789012 :invalid:http://redirect" )
163
- c = MakeCSRFCookie (req , "12345678901234567890123456789012" )
195
+ req = newHTTPRequest ( "GET" , "http://example.com/ _oauth?state=" + nonce + " :invalid:http://redirect" )
196
+ c = MakeCSRFCookie (req , nonce )
164
197
res , _ = doHttpRequest (req , c )
165
198
assert .Equal (401 , res .StatusCode , "auth callback with invalid provider shouldn't be authorised" )
166
199
167
200
// Should redirect valid request
168
- req = newDefaultHttpRequest ( "/ _oauth?state=12345678901234567890123456789012 :google:http://redirect" )
169
- c = MakeCSRFCookie (req , "12345678901234567890123456789012" )
201
+ req = newHTTPRequest ( "GET" , "http://example.com/ _oauth?state=" + nonce + " :google:http://redirect" )
202
+ c = MakeCSRFCookie (req , nonce )
170
203
res , _ = doHttpRequest (req , c )
171
- assert .Equal (307 , res .StatusCode , "valid auth callback should be allowed" )
204
+ require .Equal (307 , res .StatusCode , "valid auth callback should be allowed" )
172
205
173
206
fwd , _ := res .Location ()
174
207
assert .Equal ("http" , fwd .Scheme , "valid request should be redirected to return url" )
@@ -360,17 +393,17 @@ func TestServerRouteHost(t *testing.T) {
360
393
}
361
394
362
395
// Should block any request
363
- req := newHttpRequest ("GET" , "https://example.com/" , " /" )
396
+ req := newHTTPRequest ("GET" , "https://example.com/" )
364
397
res , _ := doHttpRequest (req , nil )
365
398
assert .Equal (307 , res .StatusCode , "request not matching any rule should require auth" )
366
399
367
400
// Should allow matching request
368
- req = newHttpRequest ("GET" , "https://api.example.com/" , " /" )
401
+ req = newHTTPRequest ("GET" , "https://api.example.com/" )
369
402
res , _ = doHttpRequest (req , nil )
370
403
assert .Equal (200 , res .StatusCode , "request matching allow rule should be allowed" )
371
404
372
405
// Should allow matching request
373
- req = newHttpRequest ("GET" , "https://sub8.example.com/" , " /" )
406
+ req = newHTTPRequest ("GET" , "https://sub8.example.com/" )
374
407
res , _ = doHttpRequest (req , nil )
375
408
assert .Equal (200 , res .StatusCode , "request matching allow rule should be allowed" )
376
409
}
@@ -386,12 +419,12 @@ func TestServerRouteMethod(t *testing.T) {
386
419
}
387
420
388
421
// Should block any request
389
- req := newHttpRequest ("GET" , "https://example.com/" , " /" )
422
+ req := newHTTPRequest ("GET" , "https://example.com/" )
390
423
res , _ := doHttpRequest (req , nil )
391
424
assert .Equal (307 , res .StatusCode , "request not matching any rule should require auth" )
392
425
393
426
// Should allow matching request
394
- req = newHttpRequest ("PUT" , "https://example.com/" , " /" )
427
+ req = newHTTPRequest ("PUT" , "https://example.com/" )
395
428
res , _ = doHttpRequest (req , nil )
396
429
assert .Equal (200 , res .StatusCode , "request matching allow rule should be allowed" )
397
430
}
@@ -441,12 +474,12 @@ func TestServerRouteQuery(t *testing.T) {
441
474
}
442
475
443
476
// Should block any request
444
- req := newHttpRequest ("GET" , "https://example.com/" , " /?q=no" )
477
+ req := newHTTPRequest ("GET" , "https://example.com/?q=no" )
445
478
res , _ := doHttpRequest (req , nil )
446
479
assert .Equal (307 , res .StatusCode , "request not matching any rule should require auth" )
447
480
448
481
// Should allow matching request
449
- req = newHttpRequest ("GET" , "https://api.example.com/" , " /?q=test123" )
482
+ req = newHTTPRequest ("GET" , "https://api.example.com/?q=test123" )
450
483
res , _ = doHttpRequest (req , nil )
451
484
assert .Equal (200 , res .StatusCode , "request matching allow rule should be allowed" )
452
485
}
@@ -531,16 +564,17 @@ func newDefaultConfig() *Config {
531
564
return config
532
565
}
533
566
567
+ // TODO: replace with newHTTPRequest("GET", "http://example.com/"+uri)
534
568
func newDefaultHttpRequest (uri string ) * http.Request {
535
- return newHttpRequest ( " " , "http://example.com/" , uri )
569
+ return newHTTPRequest ( "GET " , "http://example.com" + uri )
536
570
}
537
571
538
- func newHttpRequest (method , dest , uri string ) * http.Request {
539
- r := httptest . NewRequest ( "" , "http://should-use-x-forwarded.com" , nil )
540
- p , _ := url . Parse ( dest )
572
+ func newHTTPRequest (method , target string ) * http.Request {
573
+ u , _ := url . Parse ( target )
574
+ r := httptest . NewRequest ( method , target , nil )
541
575
r .Header .Add ("X-Forwarded-Method" , method )
542
- r .Header .Add ("X-Forwarded-Proto" , p .Scheme )
543
- r .Header .Add ("X-Forwarded-Host" , p .Host )
544
- r .Header .Add ("X-Forwarded-Uri" , uri )
576
+ r .Header .Add ("X-Forwarded-Proto" , u .Scheme )
577
+ r .Header .Add ("X-Forwarded-Host" , u .Host )
578
+ r .Header .Add ("X-Forwarded-Uri" , u . RequestURI () )
545
579
return r
546
580
}
0 commit comments