diff --git a/docs/middleware/limiter.md b/docs/middleware/limiter.md index a101fcf4b1d..efcdcf9e918 100644 --- a/docs/middleware/limiter.md +++ b/docs/middleware/limiter.md @@ -53,6 +53,13 @@ app.Use(limiter.New(limiter.Config{ return 20 }, Expiration: 30 * time.Second, + ExpirationFunc: func(c fiber.Ctx) time.Duration { + // Use longer expiration for sensitive endpoints + if c.Path() == "/login" { + return 60 * time.Second + } + return 30 * time.Second + }, KeyGenerator: func(c fiber.Ctx) string { return c.Get("x-forwarded-for") }, @@ -99,6 +106,21 @@ app.Use(limiter.New(limiter.Config{ })) ``` +## Dynamic expiration + +You can also calculate the expiration dynamically using the `ExpirationFunc` parameter. It receives the request context and allows you to set a different expiration window for each request. + +Example: + +```go +app.Use(limiter.New(limiter.Config{ + Max: 20, + ExpirationFunc: func(c fiber.Ctx) time.Duration { + return getExpirationForRoute(c.Path()) + }, +})) +``` + ## Config | Property | Type | Description | Default | @@ -108,6 +130,7 @@ app.Use(limiter.New(limiter.Config{ | MaxFunc | `func(fiber.Ctx) int` | Function that calculates the maximum number of recent connections within `Expiration` seconds before sending a 429 response. | A function that returns `cfg.Max` | | KeyGenerator | `func(fiber.Ctx) string` | Function to generate custom keys; uses `c.IP()` by default. | A function using `c.IP()` as the default | | Expiration | `time.Duration` | Duration to keep request records in memory. | 1 * time.Minute | +| ExpirationFunc | `func(fiber.Ctx) time.Duration` | Function that calculates the expiration duration dynamically. | A function that returns `cfg.Expiration` | | LimitReached | `fiber.Handler` | Called when a request exceeds the limit. | A function sending a 429 response | | SkipFailedRequests | `bool` | When set to `true`, requests with status code ≥ 400 aren't counted. | false | | SkipSuccessfulRequests | `bool` | When set to `true`, requests with status code < 400 aren't counted. | false | @@ -129,6 +152,7 @@ var ConfigDefault = Config{ return 5 }, Expiration: 1 * time.Minute, + // ExpirationFunc defaults to nil and is set dynamically to return cfg.Expiration KeyGenerator: func(c fiber.Ctx) string { return c.IP() }, diff --git a/middleware/limiter/config.go b/middleware/limiter/config.go index 9921c61654b..d44cd0c03b1 100644 --- a/middleware/limiter/config.go +++ b/middleware/limiter/config.go @@ -31,6 +31,11 @@ type Config struct { // } MaxFunc func(c fiber.Ctx) int + // A function to dynamically calculate the expiration time for rate limiter entries + // + // Default: A function that returns the static `Expiration` value from the config. + ExpirationFunc func(c fiber.Ctx) time.Duration + // KeyGenerator allows you to generate custom keys, by default c.IP() is used // // Default: func(c fiber.Ctx) string { @@ -83,6 +88,8 @@ var ConfigDefault = Config{ MaxFunc: func(_ fiber.Ctx) int { return defaultLimiterMax }, + // Note: ExpirationFunc is intentionally nil here so that configDefault() + // can create a proper closure that references the configured Expiration value. KeyGenerator: func(c fiber.Ctx) string { return c.IP() }, @@ -98,14 +105,14 @@ var ConfigDefault = Config{ // Helper function to set default values func configDefault(config ...Config) Config { - // Return default config if nothing provided + // Use default config if nothing provided + var cfg Config if len(config) < 1 { - return ConfigDefault + cfg = ConfigDefault + } else { + cfg = config[0] } - // Override default config - cfg := config[0] - // Set default values if cfg.Next == nil { cfg.Next = ConfigDefault.Next @@ -130,5 +137,10 @@ func configDefault(config ...Config) Config { return cfg.Max } } + if cfg.ExpirationFunc == nil { + cfg.ExpirationFunc = func(_ fiber.Ctx) time.Duration { + return cfg.Expiration + } + } return cfg } diff --git a/middleware/limiter/limiter_fixed.go b/middleware/limiter/limiter_fixed.go index 3466921a50c..2877c7778f5 100644 --- a/middleware/limiter/limiter_fixed.go +++ b/middleware/limiter/limiter_fixed.go @@ -19,11 +19,8 @@ func (FixedWindow) New(cfg *Config) fiber.Handler { cfg = &defaultCfg } - var ( - // Limiter variables - mux = &sync.RWMutex{} - expiration = uint64(cfg.Expiration.Seconds()) - ) + // Limiter variables + mux := &sync.RWMutex{} // Create manager to simplify storage operations ( see manager.go ) manager := newManager(cfg.Storage, !cfg.DisableValueRedaction) @@ -41,6 +38,13 @@ func (FixedWindow) New(cfg *Config) fiber.Handler { return c.Next() } + // Generate expiration from generator + expirationDuration := cfg.ExpirationFunc(c) + if expirationDuration <= 0 { + expirationDuration = ConfigDefault.Expiration + } + expiration := uint64(expirationDuration.Seconds()) + // Get key from request key := cfg.KeyGenerator(c) @@ -78,7 +82,7 @@ func (FixedWindow) New(cfg *Config) fiber.Handler { remaining := maxRequests - e.currHits // Update storage - if setErr := manager.set(reqCtx, key, e, cfg.Expiration); setErr != nil { + if setErr := manager.set(reqCtx, key, e, expirationDuration); setErr != nil { mux.Unlock() return fmt.Errorf("limiter: failed to persist state: %w", setErr) } @@ -118,7 +122,7 @@ func (FixedWindow) New(cfg *Config) fiber.Handler { e = entry e.currHits-- remaining++ - if setErr := manager.set(reqCtx, key, e, cfg.Expiration); setErr != nil { + if setErr := manager.set(reqCtx, key, e, expirationDuration); setErr != nil { mux.Unlock() return fmt.Errorf("limiter: failed to persist state: %w", setErr) } diff --git a/middleware/limiter/limiter_sliding.go b/middleware/limiter/limiter_sliding.go index 98b3bf8cce5..aadc3f8a006 100644 --- a/middleware/limiter/limiter_sliding.go +++ b/middleware/limiter/limiter_sliding.go @@ -21,11 +21,8 @@ func (SlidingWindow) New(cfg *Config) fiber.Handler { cfg = &defaultCfg } - var ( - // Limiter variables - mux = &sync.RWMutex{} - expiration = uint64(cfg.Expiration.Seconds()) - ) + // Limiter variables + mux := &sync.RWMutex{} // Create manager to simplify storage operations ( see manager.go ) manager := newManager(cfg.Storage, !cfg.DisableValueRedaction) @@ -43,6 +40,13 @@ func (SlidingWindow) New(cfg *Config) fiber.Handler { return c.Next() } + // Generate expiration from generator + expirationDuration := cfg.ExpirationFunc(c) + if expirationDuration <= 0 { + expirationDuration = ConfigDefault.Expiration + } + expiration := uint64(expirationDuration.Seconds()) + // Get key from request key := cfg.KeyGenerator(c) diff --git a/middleware/limiter/limiter_test.go b/middleware/limiter/limiter_test.go index e514595d90e..cbc08bfe653 100644 --- a/middleware/limiter/limiter_test.go +++ b/middleware/limiter/limiter_test.go @@ -11,11 +11,12 @@ import ( "testing" "time" - "github.com/gofiber/fiber/v3" - "github.com/gofiber/fiber/v3/internal/storage/memory" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" + + "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/internal/storage/memory" ) type failingLimiterStorage struct { @@ -32,6 +33,30 @@ func newFailingLimiterStorage() *failingLimiterStorage { } } +// countingFailStorage fails set operations after a specified number of successful calls +type countingFailStorage struct { + *failingLimiterStorage + setFailErr error + setCount int + failAfterN int +} + +func newCountingFailStorage(failAfterN int, err error) *countingFailStorage { + return &countingFailStorage{ + failingLimiterStorage: newFailingLimiterStorage(), + failAfterN: failAfterN, + setFailErr: err, + } +} + +func (s *countingFailStorage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error { + s.setCount++ + if s.setCount > s.failAfterN { + return s.setFailErr + } + return s.failingLimiterStorage.SetWithContext(ctx, key, val, exp) +} + type contextRecord struct { key string value string @@ -322,6 +347,37 @@ func TestLimiterFixedStorageSetErrorDisableRedaction(t *testing.T) { require.NotContains(t, captured.Error(), "[redacted]") } +func TestLimiterFixedStorageSetErrorOnSkipSuccessfulRequests(t *testing.T) { + t.Parallel() + + storage := newCountingFailStorage(1, errors.New("second set failed")) + + var captured error + app := fiber.New(fiber.Config{ + ErrorHandler: func(c fiber.Ctx, err error) error { + captured = err + return c.Status(fiber.StatusInternalServerError).SendString("storage failure") + }, + }) + + app.Use(New(Config{ + Storage: storage, + Max: 10, + Expiration: time.Second, + SkipSuccessfulRequests: true, + KeyGenerator: func(fiber.Ctx) string { return testLimiterClientKey }, + })) + app.Get("/", func(c fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) + require.NoError(t, err) + require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode) + require.Error(t, captured) + require.ErrorContains(t, captured, "limiter: failed to persist state") +} + func TestLimiterSlidingPropagatesRequestContextToStorage(t *testing.T) { t.Parallel() @@ -571,6 +627,173 @@ func Test_Limiter_With_Max_Func(t *testing.T) { require.Equal(t, 200, resp.StatusCode) } +// go test -run Test_Limiter_Fixed_ExpirationFuncOverridesStaticExpiration -race -v +func Test_Limiter_Fixed_ExpirationFuncOverridesStaticExpiration(t *testing.T) { + t.Parallel() + app := fiber.New() + + app.Use(New(Config{ + Max: 2, + Expiration: 10 * time.Second, + ExpirationFunc: func(_ fiber.Ctx) time.Duration { return 2 * time.Second }, + LimiterMiddleware: FixedWindow{}, + })) + + app.Get("/", func(c fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + + resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + + resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) + require.NoError(t, err) + require.Equal(t, fiber.StatusTooManyRequests, resp.StatusCode) + + time.Sleep(3 * time.Second) + + resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) +} + +// go test -run Test_Limiter_Sliding_ExpirationFuncOverridesStaticExpiration -race -v +func Test_Limiter_Sliding_ExpirationFuncOverridesStaticExpiration(t *testing.T) { + t.Parallel() + app := fiber.New() + + app.Use(New(Config{ + Max: 2, + Expiration: 10 * time.Second, + ExpirationFunc: func(_ fiber.Ctx) time.Duration { return 2 * time.Second }, + LimiterMiddleware: SlidingWindow{}, + })) + + app.Get("/", func(c fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + + resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + + resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) + require.NoError(t, err) + require.Equal(t, fiber.StatusTooManyRequests, resp.StatusCode) + + // Sliding window needs ~2x expiration to fully reset (considers previous window) + time.Sleep(4*time.Second + 500*time.Millisecond) + + resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) +} + +// go test -run Test_Limiter_Fixed_ExpirationFunc_FallbackOnZeroDuration -race -v +func Test_Limiter_Fixed_ExpirationFunc_FallbackOnZeroDuration(t *testing.T) { + t.Parallel() + app := fiber.New() + + app.Use(New(Config{ + Max: 1, + ExpirationFunc: func(_ fiber.Ctx) time.Duration { return 0 }, + LimiterMiddleware: FixedWindow{}, + })) + + app.Get("/", func(c fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + + resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) + require.NoError(t, err) + require.Equal(t, fiber.StatusTooManyRequests, resp.StatusCode) +} + +// go test -run Test_Limiter_Fixed_ExpirationFunc_FallbackOnNegativeDuration -race -v +func Test_Limiter_Fixed_ExpirationFunc_FallbackOnNegativeDuration(t *testing.T) { + t.Parallel() + app := fiber.New() + + app.Use(New(Config{ + Max: 1, + ExpirationFunc: func(_ fiber.Ctx) time.Duration { return -1 * time.Second }, + LimiterMiddleware: FixedWindow{}, + })) + + app.Get("/", func(c fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + + resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) + require.NoError(t, err) + require.Equal(t, fiber.StatusTooManyRequests, resp.StatusCode) +} + +// go test -run Test_Limiter_Sliding_ExpirationFunc_FallbackOnZeroDuration -race -v +func Test_Limiter_Sliding_ExpirationFunc_FallbackOnZeroDuration(t *testing.T) { + t.Parallel() + app := fiber.New() + + app.Use(New(Config{ + Max: 1, + ExpirationFunc: func(_ fiber.Ctx) time.Duration { return 0 }, + LimiterMiddleware: SlidingWindow{}, + })) + + app.Get("/", func(c fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + + resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) + require.NoError(t, err) + require.Equal(t, fiber.StatusTooManyRequests, resp.StatusCode) +} + +// go test -run Test_Limiter_Sliding_ExpirationFunc_FallbackOnNegativeDuration -race -v +func Test_Limiter_Sliding_ExpirationFunc_FallbackOnNegativeDuration(t *testing.T) { + t.Parallel() + app := fiber.New() + + app.Use(New(Config{ + Max: 1, + ExpirationFunc: func(_ fiber.Ctx) time.Duration { return -1 * time.Second }, + LimiterMiddleware: SlidingWindow{}, + })) + + app.Get("/", func(c fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode) + + resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) + require.NoError(t, err) + require.Equal(t, fiber.StatusTooManyRequests, resp.StatusCode) +} + // go test -run Test_Limiter_Concurrency_Store -race -v func Test_Limiter_Concurrency_Store(t *testing.T) { t.Parallel()