@@ -78,128 +78,152 @@ func New(options ...Option) (*Limiter, error) {
78
78
// Allow returns true if the limit is not exceeded, false otherwise.
79
79
func (l * Limiter ) Allow (ctx context.Context , cost , rate , window int64 , key string ) (
80
80
bool , func (context.Context ) error , error ,
81
+ ) {
82
+ allowed , _ , tr , err := l .allow (ctx , cost , rate , window , key )
83
+ return allowed , tr , err
84
+ }
85
+
86
+ // AllowAfter returns true if the limit is not exceeded, false otherwise.
87
+ // Additionally, it returns the time.Duration until the next allowed request.
88
+ func (l * Limiter ) AllowAfter (ctx context.Context , cost , rate , window int64 , key string ) (
89
+ bool , time.Duration , func (context.Context ) error , error ,
90
+ ) {
91
+ return l .allow (ctx , cost , rate , window , key )
92
+ }
93
+
94
+ func (l * Limiter ) allow (ctx context.Context , cost , rate , window int64 , key string ) (
95
+ bool , time.Duration , func (context.Context ) error , error ,
81
96
) {
82
97
if cost < 1 {
83
- return false , nil , fmt .Errorf ("cost must be greater than 0" )
98
+ return false , 0 , nil , fmt .Errorf ("cost must be greater than 0" )
84
99
}
85
100
if rate < 1 {
86
- return false , nil , fmt .Errorf ("rate must be greater than 0" )
101
+ return false , 0 , nil , fmt .Errorf ("rate must be greater than 0" )
87
102
}
88
103
if window < 1 {
89
- return false , nil , fmt .Errorf ("window must be greater than 0" )
104
+ return false , 0 , nil , fmt .Errorf ("window must be greater than 0" )
90
105
}
91
106
if key == "" {
92
- return false , nil , fmt .Errorf ("key must not be empty" )
107
+ return false , 0 , nil , fmt .Errorf ("key must not be empty" )
93
108
}
94
109
95
110
if l .redisSpeaker != nil {
96
111
if l .useGCRA {
97
112
defer l .getTimer (key , "redis-gcra" , rate , window )()
98
- _ , allowed , tr , err := l .redisGCRA (ctx , cost , rate , window , key )
99
- return allowed , tr , err
113
+ _ , allowed , retryAfter , tr , err := l .redisGCRA (ctx , cost , rate , window , key )
114
+ return allowed , retryAfter , tr , err
100
115
}
101
116
102
117
defer l .getTimer (key , "redis-sorted-set" , rate , window )()
103
- _ , allowed , tr , err := l .redisSortedSet (ctx , cost , rate , window , key )
104
- return allowed , tr , err
118
+ _ , allowed , retryAfter , tr , err := l .redisSortedSet (ctx , cost , rate , window , key )
119
+ return allowed , retryAfter , tr , err
105
120
}
106
121
107
122
defer l .getTimer (key , "gcra" , rate , window )()
108
- return l .gcraLimit (ctx , cost , rate , window , key )
123
+ allowed , retryAfter , tr , err := l .gcraLimit (ctx , cost , rate , window , key )
124
+ return allowed , retryAfter , tr , err
109
125
}
110
126
111
127
func (l * Limiter ) redisSortedSet (ctx context.Context , cost , rate , window int64 , key string ) (
112
- time.Duration , bool , func (context.Context ) error , error ,
128
+ time.Duration , bool , time. Duration , func (context.Context ) error , error ,
113
129
) {
114
130
res , err := sortedSetScript .Run (ctx , l .redisSpeaker , []string {key }, cost , rate , window ).Result ()
115
131
if err != nil {
116
- return 0 , false , nil , fmt .Errorf ("could not run SortedSet Redis script: %v" , err )
132
+ return 0 , false , 0 , nil , fmt .Errorf ("could not run SortedSet Redis script: %v" , err )
117
133
}
118
134
119
135
result , ok := res .([]interface {})
120
136
if ! ok {
121
- return 0 , false , nil , fmt .Errorf ("unexpected result from SortedSet Redis script of type %T: %v" , res , res )
137
+ return 0 , false , 0 , nil , fmt .Errorf ("unexpected result from SortedSet Redis script of type %T: %v" , res , res )
122
138
}
123
- if len (result ) != 2 {
124
- return 0 , false , nil , fmt .Errorf ("unexpected result from SortedSet Redis script of length %d: %+v" , len (result ), result )
139
+ if len (result ) != 3 {
140
+ return 0 , false , 0 , nil , fmt .Errorf ("unexpected result from SortedSet Redis script of length %d: %+v" , len (result ), result )
125
141
}
126
142
127
143
t , ok := result [0 ].(int64 )
128
144
if ! ok {
129
- return 0 , false , nil , fmt .Errorf ("unexpected result[0] from SortedSet Redis script of type %T: %v" , result [0 ], result [0 ])
145
+ return 0 , false , 0 , nil , fmt .Errorf ("unexpected result[0] from SortedSet Redis script of type %T: %v" , result [0 ], result [0 ])
130
146
}
131
147
redisTime := time .Duration (t ) * time .Microsecond
132
148
133
149
members , ok := result [1 ].(string )
134
150
if ! ok {
135
- return redisTime , false , nil , fmt .Errorf ("unexpected result[1] from SortedSet Redis script of type %T: %v" , result [1 ], result [1 ])
151
+ return redisTime , false , 0 , nil , fmt .Errorf ("unexpected result[1] from SortedSet Redis script of type %T: %v" , result [1 ], result [1 ])
136
152
}
137
153
if members == "" { // limit exceeded
138
- return redisTime , false , nil , nil
154
+ retryAfter , ok := result [2 ].(int64 )
155
+ if ! ok {
156
+ return redisTime , false , 0 , nil , fmt .Errorf ("unexpected result[2] from SortedSet Redis script of type %T: %v" , result [2 ], result [2 ])
157
+ }
158
+ return redisTime , false , time .Duration (retryAfter ) * time .Microsecond , nil , nil
139
159
}
140
160
141
161
r := & sortedSetRedisReturn {
142
162
key : key ,
143
163
members : strings .Split (members , "," ),
144
164
remover : l .redisSpeaker ,
145
165
}
146
- return redisTime , true , r .Return , nil
166
+ return redisTime , true , 0 , r .Return , nil
147
167
}
148
168
149
169
func (l * Limiter ) redisGCRA (ctx context.Context , cost , rate , window int64 , key string ) (
150
- time.Duration , bool , func (context.Context ) error , error ,
170
+ time.Duration , bool , time. Duration , func (context.Context ) error , error ,
151
171
) {
152
172
burst := rate
153
173
if l .gcraBurst > 0 {
154
174
burst = l .gcraBurst
155
175
}
156
176
res , err := gcraRedisScript .Run (ctx , l .redisSpeaker , []string {key }, burst , rate , window , cost ).Result ()
157
177
if err != nil {
158
- return 0 , false , nil , fmt .Errorf ("could not run GCRA Redis script: %v" , err )
178
+ return 0 , false , 0 , nil , fmt .Errorf ("could not run GCRA Redis script: %v" , err )
159
179
}
160
180
161
- result , ok := res .([]interface {} )
181
+ result , ok := res .([]any )
162
182
if ! ok {
163
- return 0 , false , nil , fmt .Errorf ("unexpected result from GCRA Redis script of type %T: %v" , res , res )
183
+ return 0 , false , 0 , nil , fmt .Errorf ("unexpected result from GCRA Redis script of type %T: %v" , res , res )
164
184
}
165
185
if len (result ) != 5 {
166
- return 0 , false , nil , fmt .Errorf ("unexpected result from GCRA Redis scrip of length %d: %+v" , len (result ), result )
186
+ return 0 , false , 0 , nil , fmt .Errorf ("unexpected result from GCRA Redis scrip of length %d: %+v" , len (result ), result )
167
187
}
168
188
169
189
t , ok := result [0 ].(int64 )
170
190
if ! ok {
171
- return 0 , false , nil , fmt .Errorf ("unexpected result[0] from GCRA Redis script of type %T: %v" , result [0 ], result [0 ])
191
+ return 0 , false , 0 , nil , fmt .Errorf ("unexpected result[0] from GCRA Redis script of type %T: %v" , result [0 ], result [0 ])
172
192
}
173
193
redisTime := time .Duration (t ) * time .Microsecond
174
194
175
195
allowed , ok := result [1 ].(int64 )
176
196
if ! ok {
177
- return redisTime , false , nil , fmt .Errorf ("unexpected result[1] from GCRA Redis script of type %T: %v" , result [1 ], result [1 ])
197
+ return redisTime , false , 0 , nil , fmt .Errorf ("unexpected result[1] from GCRA Redis script of type %T: %v" , result [1 ], result [1 ])
178
198
}
179
199
if allowed < 1 { // limit exceeded
180
- return redisTime , false , nil , nil
200
+ retryAfter , ok := result [3 ].(int64 )
201
+ if ! ok {
202
+ return redisTime , false , 0 , nil , fmt .Errorf ("unexpected result[3] from GCRA Redis script of type %T: %v" , result [3 ], result [3 ])
203
+ }
204
+ return redisTime , false , time .Duration (retryAfter ) * time .Microsecond , nil , nil
181
205
}
182
206
183
207
r := & unsupportedReturn {}
184
- return redisTime , true , r .Return , nil
208
+ return redisTime , true , 0 , r .Return , nil
185
209
}
186
210
187
211
func (l * Limiter ) gcraLimit (ctx context.Context , cost , rate , window int64 , key string ) (
188
- bool , func (context.Context ) error , error ,
212
+ bool , time. Duration , func (context.Context ) error , error ,
189
213
) {
190
214
burst := rate
191
215
if l .gcraBurst > 0 {
192
216
burst = l .gcraBurst
193
217
}
194
- allowed , err := l .gcra .limit (ctx , key , cost , burst , rate , window )
218
+ allowed , retryAfter , err := l .gcra .limit (ctx , key , cost , burst , rate , window )
195
219
if err != nil {
196
- return false , nil , fmt .Errorf ("could not limit: %w" , err )
220
+ return false , 0 , nil , fmt .Errorf ("could not limit: %w" , err )
197
221
}
198
222
if ! allowed {
199
- return false , nil , nil // limit exceeded
223
+ return false , retryAfter , nil , nil // limit exceeded
200
224
}
201
225
r := & unsupportedReturn {}
202
- return true , r .Return , nil
226
+ return true , 0 , r .Return , nil
203
227
}
204
228
205
229
func (l * Limiter ) getTimer (key , algo string , rate , window int64 ) func () {
0 commit comments