diff --git a/README.md b/README.md
index 19ae038..80e5612 100644
--- a/README.md
+++ b/README.md
@@ -4,10 +4,10 @@
# gin-rate-limit
-gin-rate-limit is a rate limiter for the gin framework. By default, it can
-only store rate limit info in memory and with redis. If you want to store it somewhere else you can make your own store
-or use third party stores. The library is new so there are no third party stores yet, so I would appreciate if someone
-could make one.
+gin-rate-limit is a rate limiter for the gin framework. By default, it
+can only store rate limit info in memory and with redis. If you want to store it somewhere else you can make your own
+store or use third party stores. The library is new so there are no third party stores yet, so I would appreciate if
+someone could make one.
Install
@@ -44,10 +44,13 @@ func main() {
RedisClient: redis.NewClient(&redis.Options{
Addr: "localhost:7680",
}),
- Rate: time.Second,
+ Rate: time.Second,
Limit: 5,
+ })
+ mw := ratelimit.RateLimiter(store, &ratelimit.Options{
+ ErrorHanlder: errorHandler,
+ KeyFunc: keyfunc,
})
- mw := ratelimit.RateLimiter(keyFunc, errorHandler, store)
server.GET("/", mw, func(c *gin.Context) {
c.String(200, "Hello World")
})
@@ -80,10 +83,13 @@ func main() {
server := gin.Default()
// This makes it so each ip can only make 5 requests per second
store := ratelimit.InMemoryStore(&ratelimit.InMemoryOptions{
- Rate: time.Second,
+ Rate: time.Second,
Limit: 5,
- })
- mw := ratelimit.RateLimiter(keyFunc, errorHandler, store)
+ })
+ mw := ratelimit.RateLimiter(store, &ratelimit.Options{
+ ErrorHanlder: errorHandler,
+ KeyFunc: keyfunc,
+ })
server.GET("/", mw, func(c *gin.Context) {
c.String(200, "Hello World")
})
@@ -93,6 +99,12 @@ func main() {
+# Custom Stores
+
+Unlike most rate limit libraries that support third party stores, gin-rate-limit only requires your store to have one
+method and lets you rate limit however you want. The default stores may not have exactly what you need, but it is easy
+to take one of the default stores and modify them. If you think your store could be useful to other people push it to GitHub.
+
Custom Store Example
```go
@@ -115,10 +127,4 @@ func (s *CustomStore) Limit(key string) (bool, time.Duration) {
}
return false, remaining
}
-
-// Your store must have a method called Skip that takes a *gin.Context and returns a bool
-func (s *CustomStore) Skip(c *gin.Context) bool {
- // return true if you dont want this request to count toward the users rate limit
- return false
-}
```
\ No newline at end of file
diff --git a/gin_rate_limit.go b/gin_rate_limit.go
index 217d85c..ed307f1 100644
--- a/gin_rate_limit.go
+++ b/gin_rate_limit.go
@@ -8,24 +8,37 @@ import (
type Store interface {
// Limit takes in a key and should return whether that key is allowed to make another request
Limit(key string) (bool, time.Duration)
- // Skip takes in a *gin.Context and should return whether the rate limiting should be skipped for this request
- Skip(c *gin.Context) bool
+}
+
+type Options struct {
+ ErrorHandler func(*gin.Context, time.Duration)
+ KeyFunc func(*gin.Context) string
+ // a function that returns true if the request should not count toward the rate limit
+ Skip func(*gin.Context) bool
}
// RateLimiter is a function to get gin.HandlerFunc
-// keyFunc: takes in *gin.Context and return a string
-// errorHandler: takes in *gin.Context and time.Duration
-// store: Store
-func RateLimiter(keyFunc func(c *gin.Context) string, errorHandler func(c *gin.Context, remaining time.Duration), s Store) gin.HandlerFunc {
+func RateLimiter(s Store, options *Options) gin.HandlerFunc {
+ if options.ErrorHandler == nil {
+ options.ErrorHandler = func(c *gin.Context, remaining time.Duration) {
+ c.Header("X-Rate-Limit-Reset", remaining.String())
+ c.String(429, "Too many requests")
+ }
+ }
+ if options.KeyFunc == nil {
+ options.KeyFunc = func(c *gin.Context) string {
+ return c.ClientIP() + c.FullPath()
+ }
+ }
return func(c *gin.Context) {
- if s.Skip(c) {
+ if options.Skip != nil && options.Skip(c) {
c.Next()
return
}
- key := keyFunc(c)
+ key := options.KeyFunc(c)
limited, remaining := s.Limit(key)
if limited {
- errorHandler(c, remaining)
+ options.ErrorHandler(c, remaining)
c.Abort()
} else {
c.Next()
diff --git a/in_memory.go b/in_memory.go
index 48f25af..6e031c8 100644
--- a/in_memory.go
+++ b/in_memory.go
@@ -1,7 +1,6 @@
package ratelimit
import (
- "github.com/gin-gonic/gin"
"sync"
"time"
)
@@ -27,7 +26,6 @@ type inMemoryStoreType struct {
rate int64
limit uint
data *sync.Map
- skip func(c *gin.Context) bool
}
func (s *inMemoryStoreType) Limit(key string) (bool, time.Duration) {
@@ -51,26 +49,16 @@ func (s *inMemoryStoreType) Limit(key string) (bool, time.Duration) {
return false, time.Duration(0)
}
-func (s *inMemoryStoreType) Skip(c *gin.Context) bool {
- if s.skip != nil {
- return s.skip(c)
- } else {
- return false
- }
-}
-
type InMemoryOptions struct {
// the user can make Limit amount of requests every Rate
Rate time.Duration
// the amount of requests that can be made every Rate
Limit uint
- // takes in a *gin.Context and should return whether the rate limiting should be skipped for this request
- Skip func(c *gin.Context) bool
}
func InMemoryStore(options *InMemoryOptions) Store {
data := &sync.Map{}
- store := inMemoryStoreType{int64(options.Rate.Seconds()), options.Limit, data, options.Skip}
+ store := inMemoryStoreType{int64(options.Rate.Seconds()), options.Limit, data}
go clearInBackground(data, store.rate)
return &store
}
diff --git a/redis.go b/redis.go
index 192c48a..681f706 100644
--- a/redis.go
+++ b/redis.go
@@ -63,21 +63,11 @@ func (s *redisStoreType) Limit(key string) (bool, time.Duration) {
return false, time.Duration(0)
}
-func (s *redisStoreType) Skip(c *gin.Context) bool {
- if s.skip != nil {
- return s.skip(c)
- } else {
- return false
- }
-}
-
type RedisOptions struct {
// the user can make Limit amount of requests every Rate
Rate time.Duration
// the amount of requests that can be made every Rate
- Limit uint
- // takes in a *gin.Context and should return whether the rate limiting should be skipped for this request
- Skip func(c *gin.Context) bool
+ Limit uint
RedisClient redis.UniversalClient
// should gin-rate-limit panic when there is an error with redis
PanicOnErr bool
@@ -90,6 +80,5 @@ func RedisStore(options *RedisOptions) Store {
limit: options.Limit,
ctx: context.TODO(),
panicOnErr: options.PanicOnErr,
- skip: options.Skip,
}
}