diff --git a/lib/service/service.go b/lib/service/service.go index 8faebdc01a9fd..e162d3a62ce51 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -4931,12 +4931,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { webServer, err = web.NewServer(web.ServerConfig{ Server: &http.Server{ - Handler: utils.ChainHTTPMiddlewares( - webHandler, - makeXForwardedForMiddleware(cfg), - limiter.MakeMiddleware(proxyLimiter), - httplib.MakeTracingMiddleware(teleport.ComponentProxy), - ), + Handler: wrapWebHandlerWithMiddlewares(webHandler, proxyLimiter, cfg), // Note: read/write timeouts *should not* be set here because it // will break some application access use-cases. ReadHeaderTimeout: defaults.ReadHeadersTimeout, @@ -6903,6 +6898,20 @@ func makeXForwardedForMiddleware(cfg *servicecfg.Config) utils.HTTPMiddleware { return utils.NoopHTTPMiddleware } +func wrapWebHandlerWithMiddlewares(webHandler http.Handler, proxyLimiter *limiter.Limiter, cfg *servicecfg.Config) http.Handler { + // Handler wrapped with ChainHTTPMiddlewares will invoke the middlewares in + // the reverse order of their specification. + // + // X-Forwarded-For middleware must be applied before the limiter so that the + // limiter operates on the real client IP instead of the load balancer IP. + return utils.ChainHTTPMiddlewares( + webHandler, + limiter.MakeMiddleware(proxyLimiter), + makeXForwardedForMiddleware(cfg), + httplib.MakeTracingMiddleware(teleport.ComponentProxy), + ) +} + // makeApplicationCORS converts a servicecfg.CORS to a types.CORS. func makeApplicationCORS(c *servicecfg.CORS) *types.CORSPolicy { if c == nil { diff --git a/lib/service/service_test.go b/lib/service/service_test.go index 63e8ec9c9ff3d..3ce05ca82b033 100644 --- a/lib/service/service_test.go +++ b/lib/service/service_test.go @@ -27,6 +27,7 @@ import ( "log/slog" "net" "net/http" + "net/http/httptest" "net/url" "os" "path/filepath" @@ -2103,3 +2104,33 @@ func makeTempDir(t *testing.T) string { t.Cleanup(func() { os.RemoveAll(tempDir) }) return tempDir } + +// Test_wrapWebHandlerWithMiddlewares_ip tests that real IP from X-Forwarded-For +// is used for limiter. +func Test_wrapWebHandlerWithMiddlewares_ip(t *testing.T) { + proxyLimiter, err := limiter.NewLimiter(limiter.Config{ + MaxConnections: 1, + }) + require.NoError(t, err) + + // Set the request from a local ip but also with a real IP in + // X-Forwarded-For. + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = "10.0.0.10:100" + req.Header.Set("X-Forwarded-For", "22.22.22.22:222") + cfg := &servicecfg.Config{ + Proxy: servicecfg.ProxyConfig{ + TrustXForwardedFor: true, + }, + } + + // Register one token to let the limiter blocks the request. Note that token + // is counted without the port. + release, err := proxyLimiter.RegisterRequestAndConnection("22.22.22.22") + require.NoError(t, err) + t.Cleanup(release) + + recorder := httptest.NewRecorder() + wrapWebHandlerWithMiddlewares(http.NotFoundHandler(), proxyLimiter, cfg).ServeHTTP(recorder, req) + require.Equal(t, http.StatusTooManyRequests, recorder.Code) +}