diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 152205ed65631..2311a171959c2 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -106,6 +106,7 @@ type Handler struct { auth *sessionCache sessionStreamPollPeriod time.Duration clock clockwork.Clock + limiter *limiter.RateLimiter // sshPort specifies the SSH proxy port extracted // from configuration sshPort string @@ -303,9 +304,9 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*APIHandler, error) { } h.sshPort = sshPort - // challengeLimiter is used to limit unauthenticated challenge generation for - // passwordless. - challengeLimiter, err := limiter.NewRateLimiter(limiter.Config{ + // rateLimiter is used to limit unauthenticated challenge generation for + // passwordless and for unauthenticated metrics. + h.limiter, err = limiter.NewRateLimiter(limiter.Config{ Rates: []limiter.Rate{ { Period: defaults.LimiterPasswordlessPeriod, @@ -323,7 +324,7 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*APIHandler, error) { if cfg.MinimalReverseTunnelRoutesOnly { h.bindMinimalEndpoints() } else { - h.bindDefaultEndpoints(challengeLimiter) + h.bindDefaultEndpoints() } // if Web UI is enabled, check the assets dir: @@ -462,7 +463,7 @@ func (h *Handler) bindMinimalEndpoints() { } // bindDefaultEndpoints binds the default endpoints for the web API. -func (h *Handler) bindDefaultEndpoints(challengeLimiter *limiter.RateLimiter) { +func (h *Handler) bindDefaultEndpoints() { h.bindMinimalEndpoints() // ping endpoint is used to check if the server is up. the /webapi/ping @@ -487,17 +488,19 @@ func (h *Handler) bindDefaultEndpoints(challengeLimiter *limiter.RateLimiter) { h.GET("/webapi/scripts/desktop-access/install-ad-cs.ps1", httplib.MakeHandler(h.desktopAccessScriptInstallADCSHandle)) h.GET("/webapi/scripts/desktop-access/configure/:token/configure-ad.ps1", httplib.MakeHandler(h.desktopAccessScriptConfigureHandle)) - // DELETE IN: 5.1.0 - // - // Migrated this endpoint to /webapi/sessions/web below. - h.POST("/webapi/sessions", httplib.WithCSRFProtection(h.createWebSession)) - - // Web sessions - h.POST("/webapi/sessions/web", httplib.WithCSRFProtection(h.createWebSession)) + // App sessions h.POST("/webapi/sessions/app", h.WithAuth(h.createAppSession)) - h.DELETE("/webapi/sessions", h.WithAuth(h.deleteSession)) - h.POST("/webapi/sessions/renew", h.WithAuth(h.renewSession)) + // DELETE IN 13, deprecated/unused web sessions routes (avatus) + // https://github.com/gravitational/teleport/pull/19892 + h.POST("/webapi/sessions", httplib.WithCSRFProtection(h.WithLimiterHandlerFunc(h.createWebSession))) + h.DELETE("/webapi/sessions", h.WithAuth(h.deleteWebSession)) + h.POST("/webapi/sessions/renew", h.WithAuth(h.renewWebSession)) + + // Web sessions + h.POST("/webapi/sessions/web", httplib.WithCSRFProtection(h.WithLimiterHandlerFunc(h.createWebSession))) + h.DELETE("/webapi/sessions/web", h.WithAuth(h.deleteWebSession)) + h.POST("/webapi/sessions/web/renew", h.WithAuth(h.renewWebSession)) h.POST("/webapi/users", h.WithAuth(h.createUserHandle)) h.PUT("/webapi/users", h.WithAuth(h.updateUserHandle)) h.GET("/webapi/users", h.WithAuth(h.getUsersHandle)) @@ -574,21 +577,21 @@ func (h *Handler) bindDefaultEndpoints(challengeLimiter *limiter.RateLimiter) { // OIDC related callback handlers h.GET("/webapi/oidc/login/web", h.WithRedirect(h.oidcLoginWeb)) h.GET("/webapi/oidc/callback", h.WithMetaRedirect(h.oidcCallback)) - h.POST("/webapi/oidc/login/console", httplib.MakeHandler(h.oidcLoginConsole)) + h.POST("/webapi/oidc/login/console", h.WithLimiter(h.oidcLoginConsole)) // SAML 2.0 handlers h.POST("/webapi/saml/acs", h.WithMetaRedirect(h.samlACS)) h.POST("/webapi/saml/acs/:connector", h.WithMetaRedirect(h.samlACS)) h.GET("/webapi/saml/sso", h.WithMetaRedirect(h.samlSSO)) - h.POST("/webapi/saml/login/console", httplib.MakeHandler(h.samlSSOConsole)) + h.POST("/webapi/saml/login/console", h.WithLimiter(h.samlSSOConsole)) // Github connector handlers h.GET("/webapi/github/login/web", h.WithRedirect(h.githubLoginWeb)) h.GET("/webapi/github/callback", h.WithMetaRedirect(h.githubCallback)) - h.POST("/webapi/github/login/console", httplib.MakeHandler(h.githubLoginConsole)) + h.POST("/webapi/github/login/console", h.WithLimiter(h.githubLoginConsole)) // MFA public endpoints. - h.POST("/webapi/mfa/login/begin", h.withLimiter(challengeLimiter, h.mfaLoginBegin)) + h.POST("/webapi/mfa/login/begin", h.WithLimiter(h.mfaLoginBegin)) h.POST("/webapi/mfa/login/finish", httplib.MakeHandler(h.mfaLoginFinish)) h.POST("/webapi/mfa/login/finishsession", httplib.MakeHandler(h.mfaLoginFinishSession)) h.DELETE("/webapi/mfa/token/:token/devices/:devicename", httplib.MakeHandler(h.deleteMFADeviceWithTokenHandle)) @@ -1743,14 +1746,14 @@ func clientMetaFromReq(r *http.Request) *auth.ForwardedClientMetadata { } } -// deleteSession is called to sign out user +// deleteWebSession is called to sign out user // // DELETE /v1/webapi/sessions/:sid // // Response: // // {"message": "ok"} -func (h *Handler) deleteSession(w http.ResponseWriter, r *http.Request, _ httprouter.Params, ctx *SessionContext) (interface{}, error) { +func (h *Handler) deleteWebSession(w http.ResponseWriter, r *http.Request, _ httprouter.Params, ctx *SessionContext) (interface{}, error) { err := h.logout(r.Context(), w, ctx) if err != nil { return nil, trace.Wrap(err) @@ -1778,14 +1781,14 @@ type renewSessionRequest struct { ReloadUser bool `json:"reloadUser"` } -// renewSession updates this existing session with a new session. +// renewWebSession updates this existing session with a new session. // // Depending on request fields sent in for extension, the new session creation can vary depending on: // - AccessRequestID (opt): appends roles approved from access request to currently assigned roles or, // - Switchback (opt): roles stacked with assuming approved access requests, will revert to user's default roles // - ReloadUser (opt): similar to default but updates user related data (e.g login traits) by retrieving it from the backend // - default (none set): create new session with currently assigned roles -func (h *Handler) renewSession(w http.ResponseWriter, r *http.Request, params httprouter.Params, ctx *SessionContext) (interface{}, error) { +func (h *Handler) renewWebSession(w http.ResponseWriter, r *http.Request, params httprouter.Params, ctx *SessionContext) (interface{}, error) { req := renewSessionRequest{} if err := httplib.ReadJSON(r, &req); err != nil { return nil, trace.Wrap(err) @@ -3064,10 +3067,22 @@ func (h *Handler) WithAuth(fn ContextHandler) httprouter.Handle { }) } -// withLimiter adds IP-based rate limiting to fn. -func (h *Handler) withLimiter(l *limiter.RateLimiter, fn httplib.HandlerFunc) httprouter.Handle { +// WithLimiter adds IP-based rate limiting to fn. +func (h *Handler) WithLimiter(fn httplib.HandlerFunc) httprouter.Handle { return httplib.MakeHandler(func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) { - err := l.RegisterRequest(r.RemoteAddr, nil /* customRate */) + return h.WithLimiterHandlerFunc(fn)(w, r, p) + }) +} + +// WithLimiterHandlerFunc adds IP-based rate limiting to a HandlerFunc. This +// should be used when you need to nest this inside another HandlerFunc. +func (h *Handler) WithLimiterHandlerFunc(fn httplib.HandlerFunc) httplib.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) { + remote, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return nil, trace.Wrap(err) + } + err = h.limiter.RegisterRequest(remote, nil /* customRate */) // MaxRateError doesn't play well with errors.Is, hence the cast. if _, ok := err.(*ratelimit.MaxRateError); ok { return nil, trace.LimitExceeded(err.Error()) @@ -3076,7 +3091,7 @@ func (h *Handler) withLimiter(l *limiter.RateLimiter, fn httplib.HandlerFunc) ht return nil, trace.Wrap(err) } return fn(w, r, p) - }) + } } // AuthenticateRequest authenticates request using combination of a session cookie diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 5223a2145a03b..2bab6915e8cb0 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -57,6 +57,7 @@ import ( "github.com/jonboulle/clockwork" "github.com/julienschmidt/httprouter" lemma_secret "github.com/mailgun/lemma/secret" + "github.com/mailgun/timetools" "github.com/pquerna/otp/totp" "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" @@ -462,7 +463,7 @@ func noCache(clt auth.ClientI, cacheName []string) (auth.RemoteProxyAccessPoint, } func (r *authPack) renewSession(ctx context.Context, t *testing.T) *roundtrip.Response { - resp, err := r.clt.PostJSON(ctx, r.clt.Endpoint("webapi", "sessions", "renew"), nil) + resp, err := r.clt.PostJSON(ctx, r.clt.Endpoint("webapi", "sessions", "web", "renew"), nil) require.NoError(t, err) return resp } @@ -760,7 +761,7 @@ func TestWebSessionsCRUD(t *testing.T) { // now delete session _, err = pack.clt.Delete( context.Background(), - pack.clt.Endpoint("webapi", "sessions")) + pack.clt.Endpoint("webapi", "sessions", "web")) require.NoError(t, err) // subsequent requests trying to use this session will fail @@ -912,7 +913,7 @@ func TestWebSessionsBadInput(t *testing.T) { } for i, req := range reqs { t.Run(fmt.Sprintf("tc %v", i), func(t *testing.T) { - _, err := clt.PostJSON(s.ctx, clt.Endpoint("webapi", "sessions"), req) + _, err := clt.PostJSON(s.ctx, clt.Endpoint("webapi", "sessions", "web"), req) require.Error(t, err) require.True(t, trace.IsAccessDenied(err)) }) @@ -1787,7 +1788,7 @@ func TestCloseConnectionsOnLogout(t *testing.T) { _, err = stream.Read(out) require.NoError(t, err) - _, err = pack.clt.Delete(s.ctx, pack.clt.Endpoint("webapi", "sessions")) + _, err = pack.clt.Delete(s.ctx, pack.clt.Endpoint("webapi", "sessions", "web")) require.NoError(t, err) // wait until we timeout or detect that connection has been closed @@ -1886,7 +1887,7 @@ func TestLogin(t *testing.T) { clt := s.client() ua := "test-ua" - req, err := http.NewRequest("POST", clt.Endpoint("webapi", "sessions"), bytes.NewBuffer(loginReq)) + req, err := http.NewRequest("POST", clt.Endpoint("webapi", "sessions", "web"), bytes.NewBuffer(loginReq)) require.NoError(t, err) req.Header.Set("User-Agent", ua) @@ -3176,7 +3177,7 @@ func TestApplicationWebSessionsDeletedAfterLogout(t *testing.T) { require.Len(t, sessions, len(applications)) // Logout from Telport. - _, err = pack.clt.Delete(context.Background(), pack.clt.Endpoint("webapi", "sessions")) + _, err = pack.clt.Delete(context.Background(), pack.clt.Endpoint("webapi", "sessions", "web")) require.NoError(t, err) // Check sessions after logout, should be empty. @@ -3921,7 +3922,7 @@ func TestWebSessionsRenewAllowsOldBearerTokenToLinger(t *testing.T) { // now delete session _, err = newPack.clt.Delete( context.Background(), - pack.clt.Endpoint("webapi", "sessions")) + pack.clt.Endpoint("webapi", "sessions", "web")) require.NoError(t, err) // subsequent requests to use this session will fail @@ -4944,7 +4945,7 @@ func (s *WebSuite) login(clt *client.WebClient, cookieToken string, reqToken str if err != nil { return nil, err } - req, err := http.NewRequest("POST", clt.Endpoint("webapi", "sessions"), bytes.NewBuffer(data)) + req, err := http.NewRequest("POST", clt.Endpoint("webapi", "sessions", "web"), bytes.NewBuffer(data)) if err != nil { return nil, err } @@ -5558,7 +5559,7 @@ func login(t *testing.T, clt *client.WebClient, cookieToken, reqToken string, re if err != nil { return nil, err } - req, err := http.NewRequest("POST", clt.Endpoint("webapi", "sessions"), bytes.NewBuffer(data)) + req, err := http.NewRequest("POST", clt.Endpoint("webapi", "sessions", "web"), bytes.NewBuffer(data)) if err != nil { return nil, err } @@ -5632,7 +5633,7 @@ func TestUserContextWithAccessRequest(t *testing.T) { accessRequestID := accessReq.GetMetadata().Name // Make a request to renew the session with the ID of the access request. - _, err = pack.clt.PostJSON(ctx, pack.clt.Endpoint("webapi", "sessions", "renew"), renewSessionRequest{ + _, err = pack.clt.PostJSON(ctx, pack.clt.Endpoint("webapi", "sessions", "web", "renew"), renewSessionRequest{ AccessRequestID: accessRequestID, }) require.NoError(t, err) @@ -5651,6 +5652,40 @@ func TestUserContextWithAccessRequest(t *testing.T) { require.Equal(t, accessRequestID, userContext.ConsumedAccessRequestID) } +func TestWithLimiterHandlerFunc(t *testing.T) { + const burst = 20 + limiter, err := limiter.NewRateLimiter(limiter.Config{ + Rates: []limiter.Rate{ + { + Period: time.Minute, + Average: 10, + Burst: burst, + }, + }, + Clock: &timetools.FreezedTime{ + CurrentTime: time.Date(2016, 6, 5, 4, 3, 2, 1, time.UTC), + }, + }) + require.NoError(t, err) + h := &Handler{limiter: limiter} + hf := h.WithLimiterHandlerFunc(func(http.ResponseWriter, *http.Request, httprouter.Params) (interface{}, error) { + return nil, nil + }) + + // Verify that a valid burst is allowed. + r := &http.Request{} + for i := 0; i < burst; i++ { + r.RemoteAddr = fmt.Sprintf("127.0.0.1:%v", i) + _, err = hf(nil, r, nil) + require.NoError(t, err, "WithLimiterHandlerFunc failed unexpectedly") + } + + // Verify that exceeding the limit causes errors. + r.RemoteAddr = fmt.Sprintf("127.0.0.1:%v", burst) + _, err = hf(nil, r, nil) + require.True(t, trace.IsLimitExceeded(err), "WithLimiterHandlerFunc returned err = %T, want trace.LimitExceededError", err) +} + // kubeClusterConfig defines the cluster to be created type kubeClusterConfig struct { name string @@ -5895,7 +5930,7 @@ func TestLogout(t *testing.T) { require.Len(t, clusters, 1) // logout from proxy 1 - _, err = pack.clt.Delete(ctx, pack.clt.Endpoint("webapi", "sessions")) + _, err = pack.clt.Delete(ctx, pack.clt.Endpoint("webapi", "sessions", "web")) require.NoError(t, err) // ensure proxy 1 invalidated the session