forked from Beanow/traefik-forward-auth
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Allow to be run without middleware + improve request reading consiste…
…ncy (thomseddon#217) Prior to this change, the request URI was only ever read from the X-Forwarded-Uri header which was only set when the container was accessed via the forwardauth middleware. As such, it was necessary to apply the treafik-forward-auth middleware to the treafik-forward-auth container when running auth host mode. This is a quirk, unnecessary complexity and is a frequent source of configuration issues.
- Loading branch information
1 parent
e49e5c7
commit 0b3d77d
Showing
6 changed files
with
74 additions
and
54 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,6 +31,37 @@ func init() { | |
* Tests | ||
*/ | ||
|
||
func TestServerRootHandler(t *testing.T) { | ||
assert := assert.New(t) | ||
config = newDefaultConfig() | ||
|
||
// X-Forwarded headers should be read into request | ||
req := httptest.NewRequest("POST", "http://should-use-x-forwarded.com/should?ignore=me", nil) | ||
req.Header.Add("X-Forwarded-Method", "GET") | ||
req.Header.Add("X-Forwarded-Proto", "https") | ||
req.Header.Add("X-Forwarded-Host", "example.com") | ||
req.Header.Add("X-Forwarded-Uri", "/foo?q=bar") | ||
NewServer().RootHandler(httptest.NewRecorder(), req) | ||
|
||
assert.Equal("GET", req.Method, "x-forwarded-method should be read into request") | ||
assert.Equal("example.com", req.Host, "x-forwarded-host should be read into request") | ||
assert.Equal("/foo", req.URL.Path, "x-forwarded-uri should be read into request") | ||
assert.Equal("/foo?q=bar", req.URL.RequestURI(), "x-forwarded-uri should be read into request") | ||
|
||
// Other X-Forwarded headers should be read in into request and original URL | ||
// should be preserved if X-Forwarded-Uri not present | ||
req = httptest.NewRequest("POST", "http://should-use-x-forwarded.com/should-not?ignore=me", nil) | ||
req.Header.Add("X-Forwarded-Method", "GET") | ||
req.Header.Add("X-Forwarded-Proto", "https") | ||
req.Header.Add("X-Forwarded-Host", "example.com") | ||
NewServer().RootHandler(httptest.NewRecorder(), req) | ||
|
||
assert.Equal("GET", req.Method, "x-forwarded-method should be read into request") | ||
assert.Equal("example.com", req.Host, "x-forwarded-host should be read into request") | ||
assert.Equal("/should-not", req.URL.Path, "request url should be preserved if x-forwarded-uri not present") | ||
assert.Equal("/should-not?ignore=me", req.URL.RequestURI(), "request url should be preserved if x-forwarded-uri not present") | ||
} | ||
|
||
func TestServerAuthHandlerInvalid(t *testing.T) { | ||
assert := assert.New(t) | ||
config = newDefaultConfig() | ||
|
@@ -90,10 +121,10 @@ func TestServerAuthHandlerExpired(t *testing.T) { | |
config.Domains = []string{"test.com"} | ||
|
||
// Should redirect expired cookie | ||
req := newDefaultHttpRequest("/foo") | ||
req := newHTTPRequest("GET", "http://example.com/foo") | ||
c := MakeCookie(req, "[email protected]") | ||
res, _ := doHttpRequest(req, c) | ||
assert.Equal(307, res.StatusCode, "request with expired cookie should be redirected") | ||
require.Equal(t, 307, res.StatusCode, "request with expired cookie should be redirected") | ||
|
||
// Check for CSRF cookie | ||
var cookie *http.Cookie | ||
|
@@ -116,7 +147,7 @@ func TestServerAuthHandlerValid(t *testing.T) { | |
config = newDefaultConfig() | ||
|
||
// Should allow valid request email | ||
req := newDefaultHttpRequest("/foo") | ||
req := newHTTPRequest("GET", "http://example.com/foo") | ||
c := MakeCookie(req, "[email protected]") | ||
config.Domains = []string{} | ||
|
||
|
@@ -131,6 +162,7 @@ func TestServerAuthHandlerValid(t *testing.T) { | |
|
||
func TestServerAuthCallback(t *testing.T) { | ||
assert := assert.New(t) | ||
require := require.New(t) | ||
config = newDefaultConfig() | ||
|
||
// Setup OAuth server | ||
|
@@ -148,27 +180,28 @@ func TestServerAuthCallback(t *testing.T) { | |
} | ||
|
||
// Should pass auth response request to callback | ||
req := newDefaultHttpRequest("/_oauth") | ||
req := newHTTPRequest("GET", "http://example.com/_oauth") | ||
res, _ := doHttpRequest(req, nil) | ||
assert.Equal(401, res.StatusCode, "auth callback without cookie shouldn't be authorised") | ||
|
||
// Should catch invalid csrf cookie | ||
req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect") | ||
nonce := "12345678901234567890123456789012" | ||
req = newHTTPRequest("GET", "http://example.com/_oauth?state="+nonce+":http://redirect") | ||
c := MakeCSRFCookie(req, "nononononononononononononononono") | ||
res, _ = doHttpRequest(req, c) | ||
assert.Equal(401, res.StatusCode, "auth callback with invalid cookie shouldn't be authorised") | ||
|
||
// Should catch invalid provider cookie | ||
req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:invalid:http://redirect") | ||
c = MakeCSRFCookie(req, "12345678901234567890123456789012") | ||
req = newHTTPRequest("GET", "http://example.com/_oauth?state="+nonce+":invalid:http://redirect") | ||
c = MakeCSRFCookie(req, nonce) | ||
res, _ = doHttpRequest(req, c) | ||
assert.Equal(401, res.StatusCode, "auth callback with invalid provider shouldn't be authorised") | ||
|
||
// Should redirect valid request | ||
req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:google:http://redirect") | ||
c = MakeCSRFCookie(req, "12345678901234567890123456789012") | ||
req = newHTTPRequest("GET", "http://example.com/_oauth?state="+nonce+":google:http://redirect") | ||
c = MakeCSRFCookie(req, nonce) | ||
res, _ = doHttpRequest(req, c) | ||
assert.Equal(307, res.StatusCode, "valid auth callback should be allowed") | ||
require.Equal(307, res.StatusCode, "valid auth callback should be allowed") | ||
|
||
fwd, _ := res.Location() | ||
assert.Equal("http", fwd.Scheme, "valid request should be redirected to return url") | ||
|
@@ -360,17 +393,17 @@ func TestServerRouteHost(t *testing.T) { | |
} | ||
|
||
// Should block any request | ||
req := newHttpRequest("GET", "https://example.com/", "/") | ||
req := newHTTPRequest("GET", "https://example.com/") | ||
res, _ := doHttpRequest(req, nil) | ||
assert.Equal(307, res.StatusCode, "request not matching any rule should require auth") | ||
|
||
// Should allow matching request | ||
req = newHttpRequest("GET", "https://api.example.com/", "/") | ||
req = newHTTPRequest("GET", "https://api.example.com/") | ||
res, _ = doHttpRequest(req, nil) | ||
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed") | ||
|
||
// Should allow matching request | ||
req = newHttpRequest("GET", "https://sub8.example.com/", "/") | ||
req = newHTTPRequest("GET", "https://sub8.example.com/") | ||
res, _ = doHttpRequest(req, nil) | ||
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed") | ||
} | ||
|
@@ -386,12 +419,12 @@ func TestServerRouteMethod(t *testing.T) { | |
} | ||
|
||
// Should block any request | ||
req := newHttpRequest("GET", "https://example.com/", "/") | ||
req := newHTTPRequest("GET", "https://example.com/") | ||
res, _ := doHttpRequest(req, nil) | ||
assert.Equal(307, res.StatusCode, "request not matching any rule should require auth") | ||
|
||
// Should allow matching request | ||
req = newHttpRequest("PUT", "https://example.com/", "/") | ||
req = newHTTPRequest("PUT", "https://example.com/") | ||
res, _ = doHttpRequest(req, nil) | ||
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed") | ||
} | ||
|
@@ -441,12 +474,12 @@ func TestServerRouteQuery(t *testing.T) { | |
} | ||
|
||
// Should block any request | ||
req := newHttpRequest("GET", "https://example.com/", "/?q=no") | ||
req := newHTTPRequest("GET", "https://example.com/?q=no") | ||
res, _ := doHttpRequest(req, nil) | ||
assert.Equal(307, res.StatusCode, "request not matching any rule should require auth") | ||
|
||
// Should allow matching request | ||
req = newHttpRequest("GET", "https://api.example.com/", "/?q=test123") | ||
req = newHTTPRequest("GET", "https://api.example.com/?q=test123") | ||
res, _ = doHttpRequest(req, nil) | ||
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed") | ||
} | ||
|
@@ -531,16 +564,17 @@ func newDefaultConfig() *Config { | |
return config | ||
} | ||
|
||
// TODO: replace with newHTTPRequest("GET", "http://example.com/"+uri) | ||
func newDefaultHttpRequest(uri string) *http.Request { | ||
return newHttpRequest("", "http://example.com/", uri) | ||
return newHTTPRequest("GET", "http://example.com"+uri) | ||
} | ||
|
||
func newHttpRequest(method, dest, uri string) *http.Request { | ||
r := httptest.NewRequest("", "http://should-use-x-forwarded.com", nil) | ||
p, _ := url.Parse(dest) | ||
func newHTTPRequest(method, target string) *http.Request { | ||
u, _ := url.Parse(target) | ||
r := httptest.NewRequest(method, target, nil) | ||
r.Header.Add("X-Forwarded-Method", method) | ||
r.Header.Add("X-Forwarded-Proto", p.Scheme) | ||
r.Header.Add("X-Forwarded-Host", p.Host) | ||
r.Header.Add("X-Forwarded-Uri", uri) | ||
r.Header.Add("X-Forwarded-Proto", u.Scheme) | ||
r.Header.Add("X-Forwarded-Host", u.Host) | ||
r.Header.Add("X-Forwarded-Uri", u.RequestURI()) | ||
return r | ||
} |