Skip to content

Commit c4317b7

Browse files
authored
Allow to be run without middleware + improve request reading consistency (#217)
Prior to this change, the request URI was only ever read from the X-Forwarded-Uri header which was only set when the container was accessed via the forwardauth middleware. As such, it was necessary to apply the treafik-forward-auth middleware to the treafik-forward-auth container when running auth host mode. This is a quirk, unnecessary complexity and is a frequent source of configuration issues.
1 parent 4ffb659 commit c4317b7

File tree

6 files changed

+74
-54
lines changed

6 files changed

+74
-54
lines changed

README.md

-4
Original file line numberDiff line numberDiff line change
@@ -426,8 +426,6 @@ spec:
426426
- name: traefik-forward-auth
427427
```
428428
429-
Note: If using auth host mode, you must apply the middleware to your auth host ingress.
430-
431429
See the examples directory for more examples.
432430
433431
#### Selective Container Authentication in Swarm
@@ -442,8 +440,6 @@ whoami:
442440
- "traefik.http.routers.whoami.middlewares=traefik-forward-auth"
443441
```
444442
445-
Note: If using auth host mode, you must apply the middleware to the traefik-forward-auth container.
446-
447443
See the examples directory for more examples.
448444
449445
#### Rules Based Authentication

examples/traefik-v2/kubernetes/advanced-separate-pod/traefik-forward-auth/ingress.yaml

-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,5 @@ spec:
1616
services:
1717
- name: traefik-forward-auth
1818
port: 4181
19-
middlewares:
20-
- name: traefik-forward-auth
2119
tls:
2220
certresolver: default

internal/auth.go

+7-14
Original file line numberDiff line numberDiff line change
@@ -125,24 +125,19 @@ func ValidateDomains(email string, domains CommaSeparatedList) bool {
125125

126126
// Get the redirect base
127127
func redirectBase(r *http.Request) string {
128-
proto := r.Header.Get("X-Forwarded-Proto")
129-
host := r.Header.Get("X-Forwarded-Host")
130-
131-
return fmt.Sprintf("%s://%s", proto, host)
128+
return fmt.Sprintf("%s://%s", r.Header.Get("X-Forwarded-Proto"), r.Host)
132129
}
133130

134131
// Return url
135132
func returnUrl(r *http.Request) string {
136-
path := r.Header.Get("X-Forwarded-Uri")
137-
138-
return fmt.Sprintf("%s%s", redirectBase(r), path)
133+
return fmt.Sprintf("%s%s", redirectBase(r), r.URL.Path)
139134
}
140135

141136
// Get oauth redirect uri
142137
func redirectUri(r *http.Request) string {
143138
if use, _ := useAuthDomain(r); use {
144-
proto := r.Header.Get("X-Forwarded-Proto")
145-
return fmt.Sprintf("%s://%s%s", proto, config.AuthHost, config.Path)
139+
p := r.Header.Get("X-Forwarded-Proto")
140+
return fmt.Sprintf("%s://%s%s", p, config.AuthHost, config.Path)
146141
}
147142

148143
return fmt.Sprintf("%s%s", redirectBase(r), config.Path)
@@ -155,7 +150,7 @@ func useAuthDomain(r *http.Request) (bool, string) {
155150
}
156151

157152
// Does the request match a given cookie domain?
158-
reqMatch, reqHost := matchCookieDomains(r.Header.Get("X-Forwarded-Host"))
153+
reqMatch, reqHost := matchCookieDomains(r.Host)
159154

160155
// Do any of the auth hosts match a cookie domain?
161156
authMatch, authHost := matchCookieDomains(config.AuthHost)
@@ -284,10 +279,8 @@ func Nonce() (error, string) {
284279

285280
// Cookie domain
286281
func cookieDomain(r *http.Request) string {
287-
host := r.Header.Get("X-Forwarded-Host")
288-
289282
// Check if any of the given cookie domains matches
290-
_, domain := matchCookieDomains(host)
283+
_, domain := matchCookieDomains(r.Host)
291284
return domain
292285
}
293286

@@ -297,7 +290,7 @@ func csrfCookieDomain(r *http.Request) string {
297290
if use, domain := useAuthDomain(r); use {
298291
host = domain
299292
} else {
300-
host = r.Header.Get("X-Forwarded-Host")
293+
host = r.Host
301294
}
302295

303296
// Remove port

internal/auth_test.go

+4-9
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package tfa
22

33
import (
44
"net/http"
5+
"net/http/httptest"
56
"net/url"
67
"strings"
78
"testing"
@@ -196,10 +197,8 @@ func TestAuthValidateEmail(t *testing.T) {
196197
func TestRedirectUri(t *testing.T) {
197198
assert := assert.New(t)
198199

199-
r, _ := http.NewRequest("GET", "http://example.com", nil)
200+
r := httptest.NewRequest("GET", "http://app.example.com/hello", nil)
200201
r.Header.Add("X-Forwarded-Proto", "http")
201-
r.Header.Add("X-Forwarded-Host", "app.example.com")
202-
r.Header.Add("X-Forwarded-Uri", "/hello")
203202

204203
//
205204
// No Auth Host
@@ -241,10 +240,8 @@ func TestRedirectUri(t *testing.T) {
241240
// With Auth URL + cookie domain, but from different domain
242241
// - will not use auth host
243242
//
244-
r, _ = http.NewRequest("GET", "http://another.com", nil)
243+
r = httptest.NewRequest("GET", "https://another.com/hello", nil)
245244
r.Header.Add("X-Forwarded-Proto", "https")
246-
r.Header.Add("X-Forwarded-Host", "another.com")
247-
r.Header.Add("X-Forwarded-Uri", "/hello")
248245

249246
config.AuthHost = "auth.example.com"
250247
config.CookieDomains = []CookieDomain{*NewCookieDomain("example.com")}
@@ -378,10 +375,8 @@ func TestValidateState(t *testing.T) {
378375
func TestMakeState(t *testing.T) {
379376
assert := assert.New(t)
380377

381-
r, _ := http.NewRequest("GET", "http://example.com", nil)
378+
r := httptest.NewRequest("GET", "http://example.com/hello", nil)
382379
r.Header.Add("X-Forwarded-Proto", "http")
383-
r.Header.Add("X-Forwarded-Host", "example.com")
384-
r.Header.Add("X-Forwarded-Uri", "/hello")
385380

386381
// Test with google
387382
p := provider.Google{}

internal/server.go

+5-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,11 @@ func (s *Server) RootHandler(w http.ResponseWriter, r *http.Request) {
5858
// Modify request
5959
r.Method = r.Header.Get("X-Forwarded-Method")
6060
r.Host = r.Header.Get("X-Forwarded-Host")
61-
r.URL, _ = url.Parse(r.Header.Get("X-Forwarded-Uri"))
61+
62+
// Read URI from header if we're acting as forward auth middleware
63+
if _, ok := r.Header["X-Forwarded-Uri"]; ok {
64+
r.URL, _ = url.Parse(r.Header.Get("X-Forwarded-Uri"))
65+
}
6266

6367
// Pass to mux
6468
s.router.ServeHTTP(w, r)

internal/server_test.go

+58-24
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,37 @@ func init() {
3131
* Tests
3232
*/
3333

34+
func TestServerRootHandler(t *testing.T) {
35+
assert := assert.New(t)
36+
config = newDefaultConfig()
37+
38+
// X-Forwarded headers should be read into request
39+
req := httptest.NewRequest("POST", "http://should-use-x-forwarded.com/should?ignore=me", nil)
40+
req.Header.Add("X-Forwarded-Method", "GET")
41+
req.Header.Add("X-Forwarded-Proto", "https")
42+
req.Header.Add("X-Forwarded-Host", "example.com")
43+
req.Header.Add("X-Forwarded-Uri", "/foo?q=bar")
44+
NewServer().RootHandler(httptest.NewRecorder(), req)
45+
46+
assert.Equal("GET", req.Method, "x-forwarded-method should be read into request")
47+
assert.Equal("example.com", req.Host, "x-forwarded-host should be read into request")
48+
assert.Equal("/foo", req.URL.Path, "x-forwarded-uri should be read into request")
49+
assert.Equal("/foo?q=bar", req.URL.RequestURI(), "x-forwarded-uri should be read into request")
50+
51+
// Other X-Forwarded headers should be read in into request and original URL
52+
// should be preserved if X-Forwarded-Uri not present
53+
req = httptest.NewRequest("POST", "http://should-use-x-forwarded.com/should-not?ignore=me", nil)
54+
req.Header.Add("X-Forwarded-Method", "GET")
55+
req.Header.Add("X-Forwarded-Proto", "https")
56+
req.Header.Add("X-Forwarded-Host", "example.com")
57+
NewServer().RootHandler(httptest.NewRecorder(), req)
58+
59+
assert.Equal("GET", req.Method, "x-forwarded-method should be read into request")
60+
assert.Equal("example.com", req.Host, "x-forwarded-host should be read into request")
61+
assert.Equal("/should-not", req.URL.Path, "request url should be preserved if x-forwarded-uri not present")
62+
assert.Equal("/should-not?ignore=me", req.URL.RequestURI(), "request url should be preserved if x-forwarded-uri not present")
63+
}
64+
3465
func TestServerAuthHandlerInvalid(t *testing.T) {
3566
assert := assert.New(t)
3667
config = newDefaultConfig()
@@ -90,10 +121,10 @@ func TestServerAuthHandlerExpired(t *testing.T) {
90121
config.Domains = []string{"test.com"}
91122

92123
// Should redirect expired cookie
93-
req := newDefaultHttpRequest("/foo")
124+
req := newHTTPRequest("GET", "http://example.com/foo")
94125
c := MakeCookie(req, "[email protected]")
95126
res, _ := doHttpRequest(req, c)
96-
assert.Equal(307, res.StatusCode, "request with expired cookie should be redirected")
127+
require.Equal(t, 307, res.StatusCode, "request with expired cookie should be redirected")
97128

98129
// Check for CSRF cookie
99130
var cookie *http.Cookie
@@ -116,7 +147,7 @@ func TestServerAuthHandlerValid(t *testing.T) {
116147
config = newDefaultConfig()
117148

118149
// Should allow valid request email
119-
req := newDefaultHttpRequest("/foo")
150+
req := newHTTPRequest("GET", "http://example.com/foo")
120151
c := MakeCookie(req, "[email protected]")
121152
config.Domains = []string{}
122153

@@ -131,6 +162,7 @@ func TestServerAuthHandlerValid(t *testing.T) {
131162

132163
func TestServerAuthCallback(t *testing.T) {
133164
assert := assert.New(t)
165+
require := require.New(t)
134166
config = newDefaultConfig()
135167

136168
// Setup OAuth server
@@ -148,27 +180,28 @@ func TestServerAuthCallback(t *testing.T) {
148180
}
149181

150182
// Should pass auth response request to callback
151-
req := newDefaultHttpRequest("/_oauth")
183+
req := newHTTPRequest("GET", "http://example.com/_oauth")
152184
res, _ := doHttpRequest(req, nil)
153185
assert.Equal(401, res.StatusCode, "auth callback without cookie shouldn't be authorised")
154186

155187
// Should catch invalid csrf cookie
156-
req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:http://redirect")
188+
nonce := "12345678901234567890123456789012"
189+
req = newHTTPRequest("GET", "http://example.com/_oauth?state="+nonce+":http://redirect")
157190
c := MakeCSRFCookie(req, "nononononononononononononononono")
158191
res, _ = doHttpRequest(req, c)
159192
assert.Equal(401, res.StatusCode, "auth callback with invalid cookie shouldn't be authorised")
160193

161194
// Should catch invalid provider cookie
162-
req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:invalid:http://redirect")
163-
c = MakeCSRFCookie(req, "12345678901234567890123456789012")
195+
req = newHTTPRequest("GET", "http://example.com/_oauth?state="+nonce+":invalid:http://redirect")
196+
c = MakeCSRFCookie(req, nonce)
164197
res, _ = doHttpRequest(req, c)
165198
assert.Equal(401, res.StatusCode, "auth callback with invalid provider shouldn't be authorised")
166199

167200
// Should redirect valid request
168-
req = newDefaultHttpRequest("/_oauth?state=12345678901234567890123456789012:google:http://redirect")
169-
c = MakeCSRFCookie(req, "12345678901234567890123456789012")
201+
req = newHTTPRequest("GET", "http://example.com/_oauth?state="+nonce+":google:http://redirect")
202+
c = MakeCSRFCookie(req, nonce)
170203
res, _ = doHttpRequest(req, c)
171-
assert.Equal(307, res.StatusCode, "valid auth callback should be allowed")
204+
require.Equal(307, res.StatusCode, "valid auth callback should be allowed")
172205

173206
fwd, _ := res.Location()
174207
assert.Equal("http", fwd.Scheme, "valid request should be redirected to return url")
@@ -360,17 +393,17 @@ func TestServerRouteHost(t *testing.T) {
360393
}
361394

362395
// Should block any request
363-
req := newHttpRequest("GET", "https://example.com/", "/")
396+
req := newHTTPRequest("GET", "https://example.com/")
364397
res, _ := doHttpRequest(req, nil)
365398
assert.Equal(307, res.StatusCode, "request not matching any rule should require auth")
366399

367400
// Should allow matching request
368-
req = newHttpRequest("GET", "https://api.example.com/", "/")
401+
req = newHTTPRequest("GET", "https://api.example.com/")
369402
res, _ = doHttpRequest(req, nil)
370403
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
371404

372405
// Should allow matching request
373-
req = newHttpRequest("GET", "https://sub8.example.com/", "/")
406+
req = newHTTPRequest("GET", "https://sub8.example.com/")
374407
res, _ = doHttpRequest(req, nil)
375408
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
376409
}
@@ -386,12 +419,12 @@ func TestServerRouteMethod(t *testing.T) {
386419
}
387420

388421
// Should block any request
389-
req := newHttpRequest("GET", "https://example.com/", "/")
422+
req := newHTTPRequest("GET", "https://example.com/")
390423
res, _ := doHttpRequest(req, nil)
391424
assert.Equal(307, res.StatusCode, "request not matching any rule should require auth")
392425

393426
// Should allow matching request
394-
req = newHttpRequest("PUT", "https://example.com/", "/")
427+
req = newHTTPRequest("PUT", "https://example.com/")
395428
res, _ = doHttpRequest(req, nil)
396429
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
397430
}
@@ -441,12 +474,12 @@ func TestServerRouteQuery(t *testing.T) {
441474
}
442475

443476
// Should block any request
444-
req := newHttpRequest("GET", "https://example.com/", "/?q=no")
477+
req := newHTTPRequest("GET", "https://example.com/?q=no")
445478
res, _ := doHttpRequest(req, nil)
446479
assert.Equal(307, res.StatusCode, "request not matching any rule should require auth")
447480

448481
// Should allow matching request
449-
req = newHttpRequest("GET", "https://api.example.com/", "/?q=test123")
482+
req = newHTTPRequest("GET", "https://api.example.com/?q=test123")
450483
res, _ = doHttpRequest(req, nil)
451484
assert.Equal(200, res.StatusCode, "request matching allow rule should be allowed")
452485
}
@@ -531,16 +564,17 @@ func newDefaultConfig() *Config {
531564
return config
532565
}
533566

567+
// TODO: replace with newHTTPRequest("GET", "http://example.com/"+uri)
534568
func newDefaultHttpRequest(uri string) *http.Request {
535-
return newHttpRequest("", "http://example.com/", uri)
569+
return newHTTPRequest("GET", "http://example.com"+uri)
536570
}
537571

538-
func newHttpRequest(method, dest, uri string) *http.Request {
539-
r := httptest.NewRequest("", "http://should-use-x-forwarded.com", nil)
540-
p, _ := url.Parse(dest)
572+
func newHTTPRequest(method, target string) *http.Request {
573+
u, _ := url.Parse(target)
574+
r := httptest.NewRequest(method, target, nil)
541575
r.Header.Add("X-Forwarded-Method", method)
542-
r.Header.Add("X-Forwarded-Proto", p.Scheme)
543-
r.Header.Add("X-Forwarded-Host", p.Host)
544-
r.Header.Add("X-Forwarded-Uri", uri)
576+
r.Header.Add("X-Forwarded-Proto", u.Scheme)
577+
r.Header.Add("X-Forwarded-Host", u.Host)
578+
r.Header.Add("X-Forwarded-Uri", u.RequestURI())
545579
return r
546580
}

0 commit comments

Comments
 (0)