Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions lib/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -4931,12 +4931,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {

webServer, err = web.NewServer(web.ServerConfig{
Server: &http.Server{
Handler: utils.ChainHTTPMiddlewares(
webHandler,
makeXForwardedForMiddleware(cfg),
limiter.MakeMiddleware(proxyLimiter),
httplib.MakeTracingMiddleware(teleport.ComponentProxy),
),
Handler: wrapWebHandlerWithMiddlewares(webHandler, proxyLimiter, cfg),
// Note: read/write timeouts *should not* be set here because it
// will break some application access use-cases.
ReadHeaderTimeout: defaults.ReadHeadersTimeout,
Expand Down Expand Up @@ -6903,6 +6898,20 @@ func makeXForwardedForMiddleware(cfg *servicecfg.Config) utils.HTTPMiddleware {
return utils.NoopHTTPMiddleware
}

func wrapWebHandlerWithMiddlewares(webHandler http.Handler, proxyLimiter *limiter.Limiter, cfg *servicecfg.Config) http.Handler {
// Handler wrapped with ChainHTTPMiddlewares will invoke the middlewares in
// the reverse order of their specification.
//
// X-Forwarded-For middleware must be applied before the limiter so that the
// limiter operates on the real client IP instead of the load balancer IP.
return utils.ChainHTTPMiddlewares(
webHandler,
limiter.MakeMiddleware(proxyLimiter),
makeXForwardedForMiddleware(cfg),
httplib.MakeTracingMiddleware(teleport.ComponentProxy),
)
}

// makeApplicationCORS converts a servicecfg.CORS to a types.CORS.
func makeApplicationCORS(c *servicecfg.CORS) *types.CORSPolicy {
if c == nil {
Expand Down
31 changes: 31 additions & 0 deletions lib/service/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"log/slog"
"net"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
Expand Down Expand Up @@ -2103,3 +2104,33 @@ func makeTempDir(t *testing.T) string {
t.Cleanup(func() { os.RemoveAll(tempDir) })
return tempDir
}

// Test_wrapWebHandlerWithMiddlewares_ip tests that real IP from X-Forwarded-For
// is used for limiter.
func Test_wrapWebHandlerWithMiddlewares_ip(t *testing.T) {
proxyLimiter, err := limiter.NewLimiter(limiter.Config{
MaxConnections: 1,
})
require.NoError(t, err)

// Set the request from a local ip but also with a real IP in
// X-Forwarded-For.
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "10.0.0.10:100"
req.Header.Set("X-Forwarded-For", "22.22.22.22:222")
cfg := &servicecfg.Config{
Proxy: servicecfg.ProxyConfig{
TrustXForwardedFor: true,
},
}

// Register one token to let the limiter blocks the request. Note that token
// is counted without the port.
release, err := proxyLimiter.RegisterRequestAndConnection("22.22.22.22")
require.NoError(t, err)
t.Cleanup(release)

recorder := httptest.NewRecorder()
wrapWebHandlerWithMiddlewares(http.NotFoundHandler(), proxyLimiter, cfg).ServeHTTP(recorder, req)
require.Equal(t, http.StatusTooManyRequests, recorder.Code)
}