Skip to content

Commit

Permalink
Allow to be run without middleware + improve request reading consiste…
Browse files Browse the repository at this point in the history
…ncy (thomseddon#217)

Prior to this change, the request URI was only ever read from the
X-Forwarded-Uri header which was only set when the container was
accessed via the forwardauth middleware. As such, it was necessary
to apply the treafik-forward-auth middleware to the treafik-forward-auth
container when running auth host mode.
This is a quirk, unnecessary complexity and is a frequent source of
configuration issues.
  • Loading branch information
thomseddon authored and mkska committed Aug 22, 2023
1 parent e49e5c7 commit 0b3d77d
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 54 deletions.
4 changes: 0 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -433,8 +433,6 @@ spec:
- name: traefik-forward-auth
```
Note: If using auth host mode, you must apply the middleware to your auth host ingress.
See the examples directory for more examples.
#### Selective Container Authentication in Swarm
Expand All @@ -449,8 +447,6 @@ whoami:
- "traefik.http.routers.whoami.middlewares=traefik-forward-auth"
```
Note: If using auth host mode, you must apply the middleware to the traefik-forward-auth container.
See the examples directory for more examples.
#### Rules Based Authentication
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,5 @@ spec:
services:
- name: traefik-forward-auth
port: 4181
middlewares:
- name: traefik-forward-auth
tls:
certresolver: default
21 changes: 7 additions & 14 deletions internal/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,24 +125,19 @@ func ValidateDomains(user string, domains CommaSeparatedList) bool {

// Get the redirect base
func redirectBase(r *http.Request) string {
proto := r.Header.Get("X-Forwarded-Proto")
host := r.Header.Get("X-Forwarded-Host")

return fmt.Sprintf("%s://%s", proto, host)
return fmt.Sprintf("%s://%s", r.Header.Get("X-Forwarded-Proto"), r.Host)
}

// Return url
func returnUrl(r *http.Request) string {
path := r.Header.Get("X-Forwarded-Uri")

return fmt.Sprintf("%s%s", redirectBase(r), path)
return fmt.Sprintf("%s%s", redirectBase(r), r.URL.Path)
}

// Get oauth redirect uri
func redirectUri(r *http.Request) string {
if use, _ := useAuthDomain(r); use {
proto := r.Header.Get("X-Forwarded-Proto")
return fmt.Sprintf("%s://%s%s", proto, config.AuthHost, config.Path)
p := r.Header.Get("X-Forwarded-Proto")
return fmt.Sprintf("%s://%s%s", p, config.AuthHost, config.Path)
}

return fmt.Sprintf("%s%s", redirectBase(r), config.Path)
Expand All @@ -155,7 +150,7 @@ func useAuthDomain(r *http.Request) (bool, string) {
}

// Does the request match a given cookie domain?
reqMatch, reqHost := matchCookieDomains(r.Header.Get("X-Forwarded-Host"))
reqMatch, reqHost := matchCookieDomains(r.Host)

// Do any of the auth hosts match a cookie domain?
authMatch, authHost := matchCookieDomains(config.AuthHost)
Expand Down Expand Up @@ -284,10 +279,8 @@ func Nonce() (error, string) {

// Cookie domain
func cookieDomain(r *http.Request) string {
host := r.Header.Get("X-Forwarded-Host")

// Check if any of the given cookie domains matches
_, domain := matchCookieDomains(host)
_, domain := matchCookieDomains(r.Host)
return domain
}

Expand All @@ -297,7 +290,7 @@ func csrfCookieDomain(r *http.Request) string {
if use, domain := useAuthDomain(r); use {
host = domain
} else {
host = r.Header.Get("X-Forwarded-Host")
host = r.Host
}

// Remove port
Expand Down
13 changes: 4 additions & 9 deletions internal/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package tfa

import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
Expand Down Expand Up @@ -212,10 +213,8 @@ func TestAuthValidateUser(t *testing.T) {
func TestRedirectUri(t *testing.T) {
assert := assert.New(t)

r, _ := http.NewRequest("GET", "http://example.com", nil)
r := httptest.NewRequest("GET", "http://app.example.com/hello", nil)
r.Header.Add("X-Forwarded-Proto", "http")
r.Header.Add("X-Forwarded-Host", "app.example.com")
r.Header.Add("X-Forwarded-Uri", "/hello")

//
// No Auth Host
Expand Down Expand Up @@ -257,10 +256,8 @@ func TestRedirectUri(t *testing.T) {
// With Auth URL + cookie domain, but from different domain
// - will not use auth host
//
r, _ = http.NewRequest("GET", "http://another.com", nil)
r = httptest.NewRequest("GET", "https://another.com/hello", nil)
r.Header.Add("X-Forwarded-Proto", "https")
r.Header.Add("X-Forwarded-Host", "another.com")
r.Header.Add("X-Forwarded-Uri", "/hello")

config.AuthHost = "auth.example.com"
config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com")}
Expand Down Expand Up @@ -394,10 +391,8 @@ func TestValidateState(t *testing.T) {
func TestMakeState(t *testing.T) {
assert := assert.New(t)

r, _ := http.NewRequest("GET", "http://example.com", nil)
r := httptest.NewRequest("GET", "http://example.com/hello", nil)
r.Header.Add("X-Forwarded-Proto", "http")
r.Header.Add("X-Forwarded-Host", "example.com")
r.Header.Add("X-Forwarded-Uri", "/hello")

// Test with google
p := provider.Google{}
Expand Down
6 changes: 5 additions & 1 deletion internal/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,11 @@ func (s *Server) RootHandler(w http.ResponseWriter, r *http.Request) {
// Modify request
r.Method = r.Header.Get("X-Forwarded-Method")
r.Host = r.Header.Get("X-Forwarded-Host")
r.URL, _ = url.Parse(r.Header.Get("X-Forwarded-Uri"))

// Read URI from header if we're acting as forward auth middleware
if _, ok := r.Header["X-Forwarded-Uri"]; ok {
r.URL, _ = url.Parse(r.Header.Get("X-Forwarded-Uri"))
}

// Pass to mux
s.router.ServeHTTP(w, r)
Expand Down
82 changes: 58 additions & 24 deletions internal/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,37 @@ func init() {
* Tests
*/

func TestServerRootHandler(t *testing.T) {
assert := assert.New(t)
config = newDefaultConfig()

// X-Forwarded headers should be read into request
req := httptest.NewRequest("POST", "http://should-use-x-forwarded.com/should?ignore=me", nil)
req.Header.Add("X-Forwarded-Method", "GET")
req.Header.Add("X-Forwarded-Proto", "https")
req.Header.Add("X-Forwarded-Host", "example.com")
req.Header.Add("X-Forwarded-Uri", "/foo?q=bar")
NewServer().RootHandler(httptest.NewRecorder(), req)

assert.Equal("GET", req.Method, "x-forwarded-method should be read into request")
assert.Equal("example.com", req.Host, "x-forwarded-host should be read into request")
assert.Equal("/foo", req.URL.Path, "x-forwarded-uri should be read into request")
assert.Equal("/foo?q=bar", req.URL.RequestURI(), "x-forwarded-uri should be read into request")

// Other X-Forwarded headers should be read in into request and original URL
// should be preserved if X-Forwarded-Uri not present
req = httptest.NewRequest("POST", "http://should-use-x-forwarded.com/should-not?ignore=me", nil)
req.Header.Add("X-Forwarded-Method", "GET")
req.Header.Add("X-Forwarded-Proto", "https")
req.Header.Add("X-Forwarded-Host", "example.com")
NewServer().RootHandler(httptest.NewRecorder(), req)

assert.Equal("GET", req.Method, "x-forwarded-method should be read into request")
assert.Equal("example.com", req.Host, "x-forwarded-host should be read into request")
assert.Equal("/should-not", req.URL.Path, "request url should be preserved if x-forwarded-uri not present")
assert.Equal("/should-not?ignore=me", req.URL.RequestURI(), "request url should be preserved if x-forwarded-uri not present")
}

func TestServerAuthHandlerInvalid(t *testing.T) {
assert := assert.New(t)
config = newDefaultConfig()
Expand Down Expand Up @@ -90,10 +121,10 @@ func TestServerAuthHandlerExpired(t *testing.T) {
config.Domains = []string{"test.com"}

// Should redirect expired cookie
req := newDefaultHttpRequest("/foo")
req := newHTTPRequest("GET", "http://example.com/foo")
c := MakeCookie(req, "[email protected]")
res, _ := doHttpRequest(req, c)
assert.Equal(307, res.StatusCode, "request with expired cookie should be redirected")
require.Equal(t, 307, res.StatusCode, "request with expired cookie should be redirected")

// Check for CSRF cookie
var cookie *http.Cookie
Expand All @@ -116,7 +147,7 @@ func TestServerAuthHandlerValid(t *testing.T) {
config = newDefaultConfig()

// Should allow valid request email
req := newDefaultHttpRequest("/foo")
req := newHTTPRequest("GET", "http://example.com/foo")
c := MakeCookie(req, "[email protected]")
config.Domains = []string{}

Expand All @@ -131,6 +162,7 @@ func TestServerAuthHandlerValid(t *testing.T) {

func TestServerAuthCallback(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
config = newDefaultConfig()

// Setup OAuth server
Expand All @@ -148,27 +180,28 @@ func TestServerAuthCallback(t *testing.T) {
}

// Should pass auth response request to callback
req := newDefaultHttpRequest("/_oauth")
req := newHTTPRequest("GET", "http://example.com/_oauth")
res, _ := doHttpRequest(req, nil)
assert.Equal(401, res.StatusCode, "auth callback without cookie shouldn't be authorised")

// Should catch invalid csrf cookie
req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect")
nonce := "12345678901234567890123456789012"
req = newHTTPRequest("GET", "http://example.com/_oauth?state="+nonce+":http://redirect")
c := MakeCSRFCookie(req, "nononononononononononononononono")
res, _ = doHttpRequest(req, c)
assert.Equal(401, res.StatusCode, "auth callback with invalid cookie shouldn't be authorised")

// Should catch invalid provider cookie
req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:invalid:http://redirect")
c = MakeCSRFCookie(req, "12345678901234567890123456789012")
req = newHTTPRequest("GET", "http://example.com/_oauth?state="+nonce+":invalid:http://redirect")
c = MakeCSRFCookie(req, nonce)
res, _ = doHttpRequest(req, c)
assert.Equal(401, res.StatusCode, "auth callback with invalid provider shouldn't be authorised")

// Should redirect valid request
req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:google:http://redirect")
c = MakeCSRFCookie(req, "12345678901234567890123456789012")
req = newHTTPRequest("GET", "http://example.com/_oauth?state="+nonce+":google:http://redirect")
c = MakeCSRFCookie(req, nonce)
res, _ = doHttpRequest(req, c)
assert.Equal(307, res.StatusCode, "valid auth callback should be allowed")
require.Equal(307, res.StatusCode, "valid auth callback should be allowed")

fwd, _ := res.Location()
assert.Equal("http", fwd.Scheme, "valid request should be redirected to return url")
Expand Down Expand Up @@ -360,17 +393,17 @@ func TestServerRouteHost(t *testing.T) {
}

// Should block any request
req := newHttpRequest("GET", "https://example.com/", "/")
req := newHTTPRequest("GET", "https://example.com/")
res, _ := doHttpRequest(req, nil)
assert.Equal(307, res.StatusCode, "request not matching any rule should require auth")

// Should allow matching request
req = newHttpRequest("GET", "https://api.example.com/", "/")
req = newHTTPRequest("GET", "https://api.example.com/")
res, _ = doHttpRequest(req, nil)
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")

// Should allow matching request
req = newHttpRequest("GET", "https://sub8.example.com/", "/")
req = newHTTPRequest("GET", "https://sub8.example.com/")
res, _ = doHttpRequest(req, nil)
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
}
Expand All @@ -386,12 +419,12 @@ func TestServerRouteMethod(t *testing.T) {
}

// Should block any request
req := newHttpRequest("GET", "https://example.com/", "/")
req := newHTTPRequest("GET", "https://example.com/")
res, _ := doHttpRequest(req, nil)
assert.Equal(307, res.StatusCode, "request not matching any rule should require auth")

// Should allow matching request
req = newHttpRequest("PUT", "https://example.com/", "/")
req = newHTTPRequest("PUT", "https://example.com/")
res, _ = doHttpRequest(req, nil)
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
}
Expand Down Expand Up @@ -441,12 +474,12 @@ func TestServerRouteQuery(t *testing.T) {
}

// Should block any request
req := newHttpRequest("GET", "https://example.com/", "/?q=no")
req := newHTTPRequest("GET", "https://example.com/?q=no")
res, _ := doHttpRequest(req, nil)
assert.Equal(307, res.StatusCode, "request not matching any rule should require auth")

// Should allow matching request
req = newHttpRequest("GET", "https://api.example.com/", "/?q=test123")
req = newHTTPRequest("GET", "https://api.example.com/?q=test123")
res, _ = doHttpRequest(req, nil)
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
}
Expand Down Expand Up @@ -531,16 +564,17 @@ func newDefaultConfig() *Config {
return config
}

// TODO: replace with newHTTPRequest("GET", "http://example.com/"+uri)
func newDefaultHttpRequest(uri string) *http.Request {
return newHttpRequest("", "http://example.com/", uri)
return newHTTPRequest("GET", "http://example.com"+uri)
}

func newHttpRequest(method, dest, uri string) *http.Request {
r := httptest.NewRequest("", "http://should-use-x-forwarded.com", nil)
p, _ := url.Parse(dest)
func newHTTPRequest(method, target string) *http.Request {
u, _ := url.Parse(target)
r := httptest.NewRequest(method, target, nil)
r.Header.Add("X-Forwarded-Method", method)
r.Header.Add("X-Forwarded-Proto", p.Scheme)
r.Header.Add("X-Forwarded-Host", p.Host)
r.Header.Add("X-Forwarded-Uri", uri)
r.Header.Add("X-Forwarded-Proto", u.Scheme)
r.Header.Add("X-Forwarded-Host", u.Host)
r.Header.Add("X-Forwarded-Uri", u.RequestURI())
return r
}

0 comments on commit 0b3d77d

Please sign in to comment.