diff --git a/README.md b/README.md index b040f84..efc4691 100644 --- a/README.md +++ b/README.md @@ -33,8 +33,8 @@ func keyFunc(c *gin.Context) string { return c.ClientIP() } -func errorHandler(c *gin.Context, remaining time.Duration) { - c.String(429, "Too many requests. Try again in "+remaining.String()) +func errorHandler(c *gin.Context, info ratelimit.Info) { + c.String(429, "Too many requests. Try again in "+time.Until(info.ResetTime).String()) } func main() { @@ -75,8 +75,8 @@ func keyFunc(c *gin.Context) string { return c.ClientIP() } -func errorHandler(c *gin.Context, remaining time.Duration) { - c.String(429, "Too many requests. Try again in "+remaining.String()) +func errorHandler(c *gin.Context, info ratelimit.Info) { + c.String(429, "Too many requests. Try again in "+time.Until(info.ResetTime).String()) } func main() { @@ -106,20 +106,26 @@ Custom Store Example package main import ( + "github.com/JGLTechnologies/gin-rate-limit" "github.com/gin-gonic/gin" - "time" ) type CustomStore struct { } -// Your store must have a method called Limit that takes a key and returns a bool, time.Duration -func (s *CustomStore) Limit(key string) (bool, time.Duration) { - // Do your rate limit logic, and return true if the user went over the rate limit, otherwise return false - // Return the amount of time the client needs to wait to make a new request +// Your store must have a method called Limit that takes a key, *gin.Context and returns ratelimit.Info +func (s *CustomStore) Limit(key string, c *gin.Context) Info { if UserWentOverLimit { - return true, remaining + return Info{ + RateLimited: true, + ResetTime: reset, + RemainingHits: 0, + } + } + return Info{ + RateLimited: false, + ResetTime: reset, + RemainingHits: remaining, } - return false, remaining } ``` \ No newline at end of file diff --git a/gin_rate_limit.go b/gin_rate_limit.go index ed307f1..96f01d6 100644 --- a/gin_rate_limit.go +++ b/gin_rate_limit.go @@ -1,44 +1,57 @@ package ratelimit import ( + "fmt" "github.com/gin-gonic/gin" "time" ) +type Info struct { + RateLimited bool + ResetTime time.Time + RemainingHits uint +} + type Store interface { - // Limit takes in a key and should return whether that key is allowed to make another request - Limit(key string) (bool, time.Duration) + // Limit takes in a key and *gin.Context and should return whether that key is allowed to make another request + Limit(key string, c *gin.Context) Info } type Options struct { - ErrorHandler func(*gin.Context, time.Duration) + ErrorHandler func(*gin.Context, Info) KeyFunc func(*gin.Context) string - // a function that returns true if the request should not count toward the rate limit - Skip func(*gin.Context) bool + // a function that lets you check the rate limiting info and modify the response + BeforeResponse func(c *gin.Context, info Info) } // RateLimiter is a function to get gin.HandlerFunc func RateLimiter(s Store, options *Options) gin.HandlerFunc { if options.ErrorHandler == nil { - options.ErrorHandler = func(c *gin.Context, remaining time.Duration) { - c.Header("X-Rate-Limit-Reset", remaining.String()) + options.ErrorHandler = func(c *gin.Context, info Info) { + c.Header("X-Rate-Limit-Reset", fmt.Sprintf("%.2f", time.Until(info.ResetTime).Seconds())) c.String(429, "Too many requests") } } + if options.BeforeResponse == nil { + options.BeforeResponse = func(c *gin.Context, info Info) { + c.Header("X-Rate-Limit-Remaining", fmt.Sprintf("%v", info.RemainingHits)) + c.Header("X-Rate-Limit-Reset", fmt.Sprintf("%.2f", time.Until(info.ResetTime).Seconds())) + } + } if options.KeyFunc == nil { options.KeyFunc = func(c *gin.Context) string { return c.ClientIP() + c.FullPath() } } return func(c *gin.Context) { - if options.Skip != nil && options.Skip(c) { - c.Next() + key := options.KeyFunc(c) + info := s.Limit(key, c) + options.BeforeResponse(c, info) + if c.IsAborted() { return } - key := options.KeyFunc(c) - limited, remaining := s.Limit(key) - if limited { - options.ErrorHandler(c, remaining) + if info.RateLimited { + options.ErrorHandler(c, info) c.Abort() } else { c.Next() diff --git a/in_memory.go b/in_memory.go index 6e031c8..a8f3db9 100644 --- a/in_memory.go +++ b/in_memory.go @@ -1,6 +1,7 @@ package ratelimit import ( + "github.com/gin-gonic/gin" "sync" "time" ) @@ -26,9 +27,10 @@ type inMemoryStoreType struct { rate int64 limit uint data *sync.Map + skip func(ctx *gin.Context) bool } -func (s *inMemoryStoreType) Limit(key string) (bool, time.Duration) { +func (s *inMemoryStoreType) Limit(key string, c *gin.Context) Info { var u user m, ok := s.data.Load(key) if !ok { @@ -39,14 +41,28 @@ func (s *inMemoryStoreType) Limit(key string) (bool, time.Duration) { if u.ts+s.rate <= time.Now().Unix() { u.tokens = s.limit } - remaining := time.Duration((s.rate - (time.Now().Unix() - u.ts)) * time.Second.Nanoseconds()) + if s.skip != nil && s.skip(c) { + return Info{ + RateLimited: false, + ResetTime: time.Now().Add(time.Duration((s.rate - (time.Now().Unix() - u.ts)) * time.Second.Nanoseconds())), + RemainingHits: u.tokens, + } + } if u.tokens <= 0 { - return true, remaining + return Info{ + RateLimited: true, + ResetTime: time.Now().Add(time.Duration((s.rate - (time.Now().Unix() - u.ts)) * time.Second.Nanoseconds())), + RemainingHits: 0, + } } u.tokens-- u.ts = time.Now().Unix() s.data.Store(key, u) - return false, time.Duration(0) + return Info{ + RateLimited: false, + ResetTime: time.Now().Add(time.Duration((s.rate - (time.Now().Unix() - u.ts)) * time.Second.Nanoseconds())), + RemainingHits: u.tokens, + } } type InMemoryOptions struct { @@ -54,11 +70,13 @@ type InMemoryOptions struct { Rate time.Duration // the amount of requests that can be made every Rate Limit uint + // a function that returns true if the request should not count toward the rate limit + Skip func(*gin.Context) bool } func InMemoryStore(options *InMemoryOptions) Store { data := &sync.Map{} - store := inMemoryStoreType{int64(options.Rate.Seconds()), options.Limit, data} + store := inMemoryStoreType{int64(options.Rate.Seconds()), options.Limit, data, options.Skip} go clearInBackground(data, store.rate) return &store } diff --git a/redis.go b/redis.go index 681f706..9ad1206 100644 --- a/redis.go +++ b/redis.go @@ -16,7 +16,7 @@ type redisStoreType struct { skip func(c *gin.Context) bool } -func (s *redisStoreType) Limit(key string) (bool, time.Duration) { +func (s *redisStoreType) Limit(key string, c *gin.Context) Info { p := s.client.Pipeline() defer p.Close() cmds, _ := s.client.Pipelined(s.ctx, func(pipeliner redis.Pipeliner) error { @@ -36,18 +36,34 @@ func (s *redisStoreType) Limit(key string) (bool, time.Duration) { hits = 0 p.Set(s.ctx, key+"hits", hits, time.Duration(0)) } - remaining := time.Duration((s.rate - (time.Now().Unix() - ts)) * time.Second.Nanoseconds()) + if s.skip != nil && s.skip(c) { + return Info{ + RateLimited: false, + ResetTime: time.Now().Add(time.Duration((s.rate - (time.Now().Unix() - ts)) * time.Second.Nanoseconds())), + RemainingHits: s.limit - uint(hits), + } + } if hits >= int64(s.limit) { _, err = p.Exec(s.ctx) if err != nil { if s.panicOnErr { panic(err) } else { - return false, time.Duration(0) + return Info{ + RateLimited: false, + ResetTime: time.Now().Add(time.Duration((s.rate - (time.Now().Unix() - ts)) * time.Second.Nanoseconds())), + RemainingHits: 0, + } } } - return true, remaining + return Info{ + RateLimited: true, + ResetTime: time.Now().Add(time.Duration((s.rate - (time.Now().Unix() - ts)) * time.Second.Nanoseconds())), + RemainingHits: 0, + } } + ts = time.Now().Unix() + hits++ p.Incr(s.ctx, key+"hits") p.Set(s.ctx, key+"ts", time.Now().Unix(), time.Duration(0)) p.Expire(s.ctx, key+"hits", time.Duration(int64(time.Second)*s.rate*2)) @@ -57,10 +73,18 @@ func (s *redisStoreType) Limit(key string) (bool, time.Duration) { if s.panicOnErr { panic(err) } else { - return false, time.Duration(0) + return Info{ + RateLimited: false, + ResetTime: time.Now().Add(time.Duration((s.rate - (time.Now().Unix() - ts)) * time.Second.Nanoseconds())), + RemainingHits: s.limit - uint(hits), + } } } - return false, time.Duration(0) + return Info{ + RateLimited: false, + ResetTime: time.Now().Add(time.Duration((s.rate - (time.Now().Unix() - ts)) * time.Second.Nanoseconds())), + RemainingHits: s.limit - uint(hits), + } } type RedisOptions struct { @@ -71,6 +95,8 @@ type RedisOptions struct { RedisClient redis.UniversalClient // should gin-rate-limit panic when there is an error with redis PanicOnErr bool + // a function that returns true if the request should not count toward the rate limit + Skip func(*gin.Context) bool } func RedisStore(options *RedisOptions) Store { @@ -80,5 +106,6 @@ func RedisStore(options *RedisOptions) Store { limit: options.Limit, ctx: context.TODO(), panicOnErr: options.PanicOnErr, + skip: options.Skip, } }