diff --git a/README.md b/README.md index 19ae038..80e5612 100644 --- a/README.md +++ b/README.md @@ -4,10 +4,10 @@ # gin-rate-limit -gin-rate-limit is a rate limiter for the gin framework. By default, it can -only store rate limit info in memory and with redis. If you want to store it somewhere else you can make your own store -or use third party stores. The library is new so there are no third party stores yet, so I would appreciate if someone -could make one. +gin-rate-limit is a rate limiter for the gin framework. By default, it +can only store rate limit info in memory and with redis. If you want to store it somewhere else you can make your own +store or use third party stores. The library is new so there are no third party stores yet, so I would appreciate if +someone could make one. Install @@ -44,10 +44,13 @@ func main() { RedisClient: redis.NewClient(&redis.Options{ Addr: "localhost:7680", }), - Rate: time.Second, + Rate: time.Second, Limit: 5, + }) + mw := ratelimit.RateLimiter(store, &ratelimit.Options{ + ErrorHanlder: errorHandler, + KeyFunc: keyfunc, }) - mw := ratelimit.RateLimiter(keyFunc, errorHandler, store) server.GET("/", mw, func(c *gin.Context) { c.String(200, "Hello World") }) @@ -80,10 +83,13 @@ func main() { server := gin.Default() // This makes it so each ip can only make 5 requests per second store := ratelimit.InMemoryStore(&ratelimit.InMemoryOptions{ - Rate: time.Second, + Rate: time.Second, Limit: 5, - }) - mw := ratelimit.RateLimiter(keyFunc, errorHandler, store) + }) + mw := ratelimit.RateLimiter(store, &ratelimit.Options{ + ErrorHanlder: errorHandler, + KeyFunc: keyfunc, + }) server.GET("/", mw, func(c *gin.Context) { c.String(200, "Hello World") }) @@ -93,6 +99,12 @@ func main() {
+# Custom Stores + +Unlike most rate limit libraries that support third party stores, gin-rate-limit only requires your store to have one +method and lets you rate limit however you want. The default stores may not have exactly what you need, but it is easy +to take one of the default stores and modify them. If you think your store could be useful to other people push it to GitHub. + Custom Store Example ```go @@ -115,10 +127,4 @@ func (s *CustomStore) Limit(key string) (bool, time.Duration) { } return false, remaining } - -// Your store must have a method called Skip that takes a *gin.Context and returns a bool -func (s *CustomStore) Skip(c *gin.Context) bool { - // return true if you dont want this request to count toward the users rate limit - return false -} ``` \ No newline at end of file diff --git a/gin_rate_limit.go b/gin_rate_limit.go index 217d85c..ed307f1 100644 --- a/gin_rate_limit.go +++ b/gin_rate_limit.go @@ -8,24 +8,37 @@ import ( type Store interface { // Limit takes in a key and should return whether that key is allowed to make another request Limit(key string) (bool, time.Duration) - // Skip takes in a *gin.Context and should return whether the rate limiting should be skipped for this request - Skip(c *gin.Context) bool +} + +type Options struct { + ErrorHandler func(*gin.Context, time.Duration) + KeyFunc func(*gin.Context) string + // a function that returns true if the request should not count toward the rate limit + Skip func(*gin.Context) bool } // RateLimiter is a function to get gin.HandlerFunc -// keyFunc: takes in *gin.Context and return a string -// errorHandler: takes in *gin.Context and time.Duration -// store: Store -func RateLimiter(keyFunc func(c *gin.Context) string, errorHandler func(c *gin.Context, remaining time.Duration), s Store) gin.HandlerFunc { +func RateLimiter(s Store, options *Options) gin.HandlerFunc { + if options.ErrorHandler == nil { + options.ErrorHandler = func(c *gin.Context, remaining time.Duration) { + c.Header("X-Rate-Limit-Reset", remaining.String()) + c.String(429, "Too many requests") + } + } + if options.KeyFunc == nil { + options.KeyFunc = func(c *gin.Context) string { + return c.ClientIP() + c.FullPath() + } + } return func(c *gin.Context) { - if s.Skip(c) { + if options.Skip != nil && options.Skip(c) { c.Next() return } - key := keyFunc(c) + key := options.KeyFunc(c) limited, remaining := s.Limit(key) if limited { - errorHandler(c, remaining) + options.ErrorHandler(c, remaining) c.Abort() } else { c.Next() diff --git a/in_memory.go b/in_memory.go index 48f25af..6e031c8 100644 --- a/in_memory.go +++ b/in_memory.go @@ -1,7 +1,6 @@ package ratelimit import ( - "github.com/gin-gonic/gin" "sync" "time" ) @@ -27,7 +26,6 @@ type inMemoryStoreType struct { rate int64 limit uint data *sync.Map - skip func(c *gin.Context) bool } func (s *inMemoryStoreType) Limit(key string) (bool, time.Duration) { @@ -51,26 +49,16 @@ func (s *inMemoryStoreType) Limit(key string) (bool, time.Duration) { return false, time.Duration(0) } -func (s *inMemoryStoreType) Skip(c *gin.Context) bool { - if s.skip != nil { - return s.skip(c) - } else { - return false - } -} - type InMemoryOptions struct { // the user can make Limit amount of requests every Rate Rate time.Duration // the amount of requests that can be made every Rate Limit uint - // takes in a *gin.Context and should return whether the rate limiting should be skipped for this request - Skip func(c *gin.Context) bool } func InMemoryStore(options *InMemoryOptions) Store { data := &sync.Map{} - store := inMemoryStoreType{int64(options.Rate.Seconds()), options.Limit, data, options.Skip} + store := inMemoryStoreType{int64(options.Rate.Seconds()), options.Limit, data} go clearInBackground(data, store.rate) return &store } diff --git a/redis.go b/redis.go index 192c48a..681f706 100644 --- a/redis.go +++ b/redis.go @@ -63,21 +63,11 @@ func (s *redisStoreType) Limit(key string) (bool, time.Duration) { return false, time.Duration(0) } -func (s *redisStoreType) Skip(c *gin.Context) bool { - if s.skip != nil { - return s.skip(c) - } else { - return false - } -} - type RedisOptions struct { // the user can make Limit amount of requests every Rate Rate time.Duration // the amount of requests that can be made every Rate - Limit uint - // takes in a *gin.Context and should return whether the rate limiting should be skipped for this request - Skip func(c *gin.Context) bool + Limit uint RedisClient redis.UniversalClient // should gin-rate-limit panic when there is an error with redis PanicOnErr bool @@ -90,6 +80,5 @@ func RedisStore(options *RedisOptions) Store { limit: options.Limit, ctx: context.TODO(), panicOnErr: options.PanicOnErr, - skip: options.Skip, } }