diff --git a/README.md b/README.md index f0bcfc1..19ae038 100644 --- a/README.md +++ b/README.md @@ -40,9 +40,13 @@ func errorHandler(c *gin.Context, remaining time.Duration) { func main() { server := gin.Default() // This makes it so each ip can only make 5 requests per second - store := ratelimit.RedisStore(time.Second, 5, redis.NewClient(&redis.Options{ - Addr: "localhost:7680", - }), false) + store := ratelimit.RedisStore(&ratelimit.RedisOptions{ + RedisClient: redis.NewClient(&redis.Options{ + Addr: "localhost:7680", + }), + Rate: time.Second, + Limit: 5, + }) mw := ratelimit.RateLimiter(keyFunc, errorHandler, store) server.GET("/", mw, func(c *gin.Context) { c.String(200, "Hello World") @@ -75,7 +79,10 @@ func errorHandler(c *gin.Context, remaining time.Duration) { func main() { server := gin.Default() // This makes it so each ip can only make 5 requests per second - store := ratelimit.InMemoryStore(time.Second, 5) + store := ratelimit.InMemoryStore(&ratelimit.InMemoryOptions{ + Rate: time.Second, + Limit: 5, + }) mw := ratelimit.RateLimiter(keyFunc, errorHandler, store) server.GET("/", mw, func(c *gin.Context) { c.String(200, "Hello World") @@ -91,18 +98,27 @@ Custom Store Example ```go package main -import "time" +import ( + "github.com/gin-gonic/gin" + "time" +) type CustomStore struct { } -// Your store must have a method called Limit that takes a key and returns a bool +// Your store must have a method called Limit that takes a key and returns a bool, time.Duration func (s *CustomStore) Limit(key string) (bool, time.Duration) { // Do your rate limit logic, and return true if the user went over the rate limit, otherwise return false // Return the amount of time the client needs to wait to make a new request if UserWentOverLimit { - return true + return true, remaining } + return false, remaining +} + +// Your store must have a method called Skip that takes a *gin.Context and returns a bool +func (s *CustomStore) Skip(c *gin.Context) bool { + // return true if you dont want this request to count toward the users rate limit return false } ``` \ No newline at end of file diff --git a/gin_rate_limit.go b/gin_rate_limit.go new file mode 100644 index 0000000..41537ad --- /dev/null +++ b/gin_rate_limit.go @@ -0,0 +1,28 @@ +package ratelimit + +import ( + "github.com/gin-gonic/gin" + "time" +) + +type Store interface { + Limit(key string) (bool, time.Duration) + Skip(c *gin.Context) bool +} + +func RateLimiter(keyFunc func(c *gin.Context) string, errorHandler func(c *gin.Context, remaining time.Duration), s Store) func(ctx *gin.Context) { + return func(c *gin.Context) { + if s.Skip(c) { + c.Next() + return + } + key := keyFunc(c) + limited, remaining := s.Limit(key) + if limited { + errorHandler(c, remaining) + c.Abort() + } else { + c.Next() + } + } +} diff --git a/go.mod b/go.mod index bb31750..597accb 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/JGLTechnologies/gin-rate-limit -go 1.18 +go 1.17 require ( github.com/gin-gonic/gin v1.8.1 diff --git a/in_memory.go b/in_memory.go index baa3dec..2bfb5e4 100644 --- a/in_memory.go +++ b/in_memory.go @@ -8,12 +8,12 @@ import ( type user struct { ts int64 - tokens int + tokens uint } func clearInBackground(data *sync.Map, rate int64) { for { - data.Range(func(k, v any) bool { + data.Range(func(k, v interface{}) bool { if v.(user).ts+rate <= time.Now().Unix() { data.Delete(k) } @@ -23,13 +23,14 @@ func clearInBackground(data *sync.Map, rate int64) { } } -type InMemoryStoreType struct { +type inMemoryStoreType struct { rate int64 - limit int + limit uint data *sync.Map + skip func(c *gin.Context) bool } -func (s *InMemoryStoreType) Limit(key string) (bool, time.Duration) { +func (s *inMemoryStoreType) Limit(key string) (bool, time.Duration) { var u user m, ok := s.data.Load(key) if !ok { @@ -50,26 +51,23 @@ func (s *InMemoryStoreType) Limit(key string) (bool, time.Duration) { return false, time.Duration(0) } -type store interface { - Limit(key string) (bool, time.Duration) +func (s *inMemoryStoreType) Skip(c *gin.Context) bool { + if s.skip != nil { + return s.skip(c) + } else { + return false + } +} + +type InMemoryOptions struct { + Rate time.Duration + Limit uint + Skip func(c *gin.Context) bool } -func InMemoryStore(rate time.Duration, limit int) *InMemoryStoreType { +func InMemoryStore(options *InMemoryOptions) Store { data := &sync.Map{} - store := InMemoryStoreType{int64(rate.Seconds()), limit, data} + store := inMemoryStoreType{int64(options.Rate.Seconds()), options.Limit, data, options.Skip} go clearInBackground(data, store.rate) return &store } - -func RateLimiter(keyFunc func(c *gin.Context) string, errorHandler func(c *gin.Context, remaining time.Duration), s store) func(ctx *gin.Context) { - return func(c *gin.Context) { - key := keyFunc(c) - limited, remaining := s.Limit(key) - if limited { - errorHandler(c, remaining) - c.Abort() - } else { - c.Next() - } - } -} diff --git a/redis.go b/redis.go index fa93ade..95c3424 100644 --- a/redis.go +++ b/redis.go @@ -2,19 +2,21 @@ package ratelimit import ( "context" + "github.com/gin-gonic/gin" "github.com/go-redis/redis/v8" "time" ) -type RedisStoreType struct { +type redisStoreType struct { rate int64 - limit int + limit uint client redis.UniversalClient ctx context.Context panicOnErr bool + skip func(c *gin.Context) bool } -func (s *RedisStoreType) Limit(key string) (bool, time.Duration) { +func (s *redisStoreType) Limit(key string) (bool, time.Duration) { p := s.client.Pipeline() defer p.Close() cmds, _ := s.client.Pipelined(s.ctx, func(pipeliner redis.Pipeliner) error { @@ -61,6 +63,29 @@ func (s *RedisStoreType) Limit(key string) (bool, time.Duration) { return false, time.Duration(0) } -func RedisStore(rate time.Duration, limit int, redisClient redis.UniversalClient, panicOnErr bool) *RedisStoreType { - return &RedisStoreType{client: redisClient, rate: int64(rate.Seconds()), limit: limit, ctx: context.TODO(), panicOnErr: panicOnErr} +func (s *redisStoreType) Skip(c *gin.Context) bool { + if s.skip != nil { + return s.skip(c) + } else { + return false + } +} + +type RedisOptions struct { + Rate time.Duration + Limit uint + RedisClient redis.UniversalClient + Skip func(c *gin.Context) bool + PanicOnErr bool +} + +func RedisStore(options *RedisOptions) Store { + return &redisStoreType{ + client: options.RedisClient, + rate: int64(options.Rate.Seconds()), + limit: options.Limit, + ctx: context.TODO(), + panicOnErr: options.PanicOnErr, + skip: options.Skip, + } }