Skip to content

Commit c2904bf

Browse files
authored
feat: throttling with retryAfter (#422)
1 parent c992d13 commit c2904bf

File tree

6 files changed

+197
-47
lines changed

6 files changed

+197
-47
lines changed

throttling/lua/gcra.lua

+3-3
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ if remaining < 0 then
5151
current_time_micro,
5252
0, -- allowed
5353
0, -- remaining
54-
tostring(retry_after),
55-
tostring(reset_after),
54+
tonumber(retry_after),
55+
tonumber(reset_after),
5656
}
5757
end
5858

@@ -62,4 +62,4 @@ if reset_after > 0 then
6262
end
6363

6464
local retry_after = -1
65-
return { current_time_micro, cost, remaining, tostring(retry_after), tostring(reset_after) }
65+
return { current_time_micro, cost, remaining, tonumber(retry_after), tonumber(reset_after) }

throttling/lua/sortedset.lua

+9-2
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,10 @@ local used_tokens = redis.call('ZCARD', key)
3434

3535
-- If the number of requests is greater than the max requests we hit the limit
3636
if (used_tokens + cost) > tonumber(rate) then
37-
return { current_time_micro, "" }
37+
local next_to_expire = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES')[2]
38+
local retry_after = next_to_expire + period - current_time_micro
39+
40+
return { current_time_micro, "", retry_after }
3841
end
3942

4043
-- seed needed to generate random members in case of collision
@@ -63,4 +66,8 @@ end
6366
redis.call('EXPIRE', key, period)
6467

6568
members = members:sub(1, -2) -- remove the last comma
66-
return { current_time_micro, members }
69+
return {
70+
current_time_micro,
71+
members,
72+
0 -- no retry_after
73+
}

throttling/memory_gcra.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,19 @@ type gcra struct {
2020
}
2121

2222
func (g *gcra) limit(ctx context.Context, key string, cost, burst, rate, period int64) (
23-
bool, error,
23+
bool, time.Duration, error,
2424
) {
2525
rl, err := g.getLimiter(key, burst, rate, period)
2626
if err != nil {
27-
return false, err
27+
return false, 0, err
2828
}
2929

30-
limited, _, err := rl.RateLimitCtx(ctx, "key", int(cost))
30+
limited, res, err := rl.RateLimitCtx(ctx, "key", int(cost))
3131
if err != nil {
32-
return false, fmt.Errorf("could not rate limit: %w", err)
32+
return false, 0, fmt.Errorf("could not rate limit: %w", err)
3333
}
3434

35-
return !limited, nil
35+
return !limited, res.RetryAfter, nil
3636
}
3737

3838
func (g *gcra) getLimiter(key string, burst, rate, period int64) (*throttled.GCRARateLimiterCtx, error) {

throttling/memory_gcra_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,15 @@ func TestMemoryGCRA(t *testing.T) {
1616
rate := int64(1)
1717
period := int64(1)
1818

19-
allowed, err := l.limit(context.Background(), "key", burst+rate, burst, rate, period)
19+
allowed, _, err := l.limit(context.Background(), "key", burst+rate, burst, rate, period)
2020
require.NoError(t, err)
2121
require.True(t, allowed, "it should be able to fill the bucket (burst)")
2222

2323
// next request should be allowed after 5 seconds
2424
start := time.Now()
2525

2626
require.Eventually(t, func() bool {
27-
allowed, err := l.limit(context.Background(), "key", burst, burst, rate, period)
27+
allowed, _, err := l.limit(context.Background(), "key", burst, burst, rate, period)
2828
if err != nil {
2929
t.Logf("Memory GCRA error: %v", err)
3030
return false

throttling/throttling.go

+56-32
Original file line numberDiff line numberDiff line change
@@ -78,128 +78,152 @@ func New(options ...Option) (*Limiter, error) {
7878
// Allow returns true if the limit is not exceeded, false otherwise.
7979
func (l *Limiter) Allow(ctx context.Context, cost, rate, window int64, key string) (
8080
bool, func(context.Context) error, error,
81+
) {
82+
allowed, _, tr, err := l.allow(ctx, cost, rate, window, key)
83+
return allowed, tr, err
84+
}
85+
86+
// AllowAfter returns true if the limit is not exceeded, false otherwise.
87+
// Additionally, it returns the time.Duration until the next allowed request.
88+
func (l *Limiter) AllowAfter(ctx context.Context, cost, rate, window int64, key string) (
89+
bool, time.Duration, func(context.Context) error, error,
90+
) {
91+
return l.allow(ctx, cost, rate, window, key)
92+
}
93+
94+
func (l *Limiter) allow(ctx context.Context, cost, rate, window int64, key string) (
95+
bool, time.Duration, func(context.Context) error, error,
8196
) {
8297
if cost < 1 {
83-
return false, nil, fmt.Errorf("cost must be greater than 0")
98+
return false, 0, nil, fmt.Errorf("cost must be greater than 0")
8499
}
85100
if rate < 1 {
86-
return false, nil, fmt.Errorf("rate must be greater than 0")
101+
return false, 0, nil, fmt.Errorf("rate must be greater than 0")
87102
}
88103
if window < 1 {
89-
return false, nil, fmt.Errorf("window must be greater than 0")
104+
return false, 0, nil, fmt.Errorf("window must be greater than 0")
90105
}
91106
if key == "" {
92-
return false, nil, fmt.Errorf("key must not be empty")
107+
return false, 0, nil, fmt.Errorf("key must not be empty")
93108
}
94109

95110
if l.redisSpeaker != nil {
96111
if l.useGCRA {
97112
defer l.getTimer(key, "redis-gcra", rate, window)()
98-
_, allowed, tr, err := l.redisGCRA(ctx, cost, rate, window, key)
99-
return allowed, tr, err
113+
_, allowed, retryAfter, tr, err := l.redisGCRA(ctx, cost, rate, window, key)
114+
return allowed, retryAfter, tr, err
100115
}
101116

102117
defer l.getTimer(key, "redis-sorted-set", rate, window)()
103-
_, allowed, tr, err := l.redisSortedSet(ctx, cost, rate, window, key)
104-
return allowed, tr, err
118+
_, allowed, retryAfter, tr, err := l.redisSortedSet(ctx, cost, rate, window, key)
119+
return allowed, retryAfter, tr, err
105120
}
106121

107122
defer l.getTimer(key, "gcra", rate, window)()
108-
return l.gcraLimit(ctx, cost, rate, window, key)
123+
allowed, retryAfter, tr, err := l.gcraLimit(ctx, cost, rate, window, key)
124+
return allowed, retryAfter, tr, err
109125
}
110126

111127
func (l *Limiter) redisSortedSet(ctx context.Context, cost, rate, window int64, key string) (
112-
time.Duration, bool, func(context.Context) error, error,
128+
time.Duration, bool, time.Duration, func(context.Context) error, error,
113129
) {
114130
res, err := sortedSetScript.Run(ctx, l.redisSpeaker, []string{key}, cost, rate, window).Result()
115131
if err != nil {
116-
return 0, false, nil, fmt.Errorf("could not run SortedSet Redis script: %v", err)
132+
return 0, false, 0, nil, fmt.Errorf("could not run SortedSet Redis script: %v", err)
117133
}
118134

119135
result, ok := res.([]interface{})
120136
if !ok {
121-
return 0, false, nil, fmt.Errorf("unexpected result from SortedSet Redis script of type %T: %v", res, res)
137+
return 0, false, 0, nil, fmt.Errorf("unexpected result from SortedSet Redis script of type %T: %v", res, res)
122138
}
123-
if len(result) != 2 {
124-
return 0, false, nil, fmt.Errorf("unexpected result from SortedSet Redis script of length %d: %+v", len(result), result)
139+
if len(result) != 3 {
140+
return 0, false, 0, nil, fmt.Errorf("unexpected result from SortedSet Redis script of length %d: %+v", len(result), result)
125141
}
126142

127143
t, ok := result[0].(int64)
128144
if !ok {
129-
return 0, false, nil, fmt.Errorf("unexpected result[0] from SortedSet Redis script of type %T: %v", result[0], result[0])
145+
return 0, false, 0, nil, fmt.Errorf("unexpected result[0] from SortedSet Redis script of type %T: %v", result[0], result[0])
130146
}
131147
redisTime := time.Duration(t) * time.Microsecond
132148

133149
members, ok := result[1].(string)
134150
if !ok {
135-
return redisTime, false, nil, fmt.Errorf("unexpected result[1] from SortedSet Redis script of type %T: %v", result[1], result[1])
151+
return redisTime, false, 0, nil, fmt.Errorf("unexpected result[1] from SortedSet Redis script of type %T: %v", result[1], result[1])
136152
}
137153
if members == "" { // limit exceeded
138-
return redisTime, false, nil, nil
154+
retryAfter, ok := result[2].(int64)
155+
if !ok {
156+
return redisTime, false, 0, nil, fmt.Errorf("unexpected result[2] from SortedSet Redis script of type %T: %v", result[2], result[2])
157+
}
158+
return redisTime, false, time.Duration(retryAfter) * time.Microsecond, nil, nil
139159
}
140160

141161
r := &sortedSetRedisReturn{
142162
key: key,
143163
members: strings.Split(members, ","),
144164
remover: l.redisSpeaker,
145165
}
146-
return redisTime, true, r.Return, nil
166+
return redisTime, true, 0, r.Return, nil
147167
}
148168

149169
func (l *Limiter) redisGCRA(ctx context.Context, cost, rate, window int64, key string) (
150-
time.Duration, bool, func(context.Context) error, error,
170+
time.Duration, bool, time.Duration, func(context.Context) error, error,
151171
) {
152172
burst := rate
153173
if l.gcraBurst > 0 {
154174
burst = l.gcraBurst
155175
}
156176
res, err := gcraRedisScript.Run(ctx, l.redisSpeaker, []string{key}, burst, rate, window, cost).Result()
157177
if err != nil {
158-
return 0, false, nil, fmt.Errorf("could not run GCRA Redis script: %v", err)
178+
return 0, false, 0, nil, fmt.Errorf("could not run GCRA Redis script: %v", err)
159179
}
160180

161-
result, ok := res.([]interface{})
181+
result, ok := res.([]any)
162182
if !ok {
163-
return 0, false, nil, fmt.Errorf("unexpected result from GCRA Redis script of type %T: %v", res, res)
183+
return 0, false, 0, nil, fmt.Errorf("unexpected result from GCRA Redis script of type %T: %v", res, res)
164184
}
165185
if len(result) != 5 {
166-
return 0, false, nil, fmt.Errorf("unexpected result from GCRA Redis scrip of length %d: %+v", len(result), result)
186+
return 0, false, 0, nil, fmt.Errorf("unexpected result from GCRA Redis scrip of length %d: %+v", len(result), result)
167187
}
168188

169189
t, ok := result[0].(int64)
170190
if !ok {
171-
return 0, false, nil, fmt.Errorf("unexpected result[0] from GCRA Redis script of type %T: %v", result[0], result[0])
191+
return 0, false, 0, nil, fmt.Errorf("unexpected result[0] from GCRA Redis script of type %T: %v", result[0], result[0])
172192
}
173193
redisTime := time.Duration(t) * time.Microsecond
174194

175195
allowed, ok := result[1].(int64)
176196
if !ok {
177-
return redisTime, false, nil, fmt.Errorf("unexpected result[1] from GCRA Redis script of type %T: %v", result[1], result[1])
197+
return redisTime, false, 0, nil, fmt.Errorf("unexpected result[1] from GCRA Redis script of type %T: %v", result[1], result[1])
178198
}
179199
if allowed < 1 { // limit exceeded
180-
return redisTime, false, nil, nil
200+
retryAfter, ok := result[3].(int64)
201+
if !ok {
202+
return redisTime, false, 0, nil, fmt.Errorf("unexpected result[3] from GCRA Redis script of type %T: %v", result[3], result[3])
203+
}
204+
return redisTime, false, time.Duration(retryAfter) * time.Microsecond, nil, nil
181205
}
182206

183207
r := &unsupportedReturn{}
184-
return redisTime, true, r.Return, nil
208+
return redisTime, true, 0, r.Return, nil
185209
}
186210

187211
func (l *Limiter) gcraLimit(ctx context.Context, cost, rate, window int64, key string) (
188-
bool, func(context.Context) error, error,
212+
bool, time.Duration, func(context.Context) error, error,
189213
) {
190214
burst := rate
191215
if l.gcraBurst > 0 {
192216
burst = l.gcraBurst
193217
}
194-
allowed, err := l.gcra.limit(ctx, key, cost, burst, rate, window)
218+
allowed, retryAfter, err := l.gcra.limit(ctx, key, cost, burst, rate, window)
195219
if err != nil {
196-
return false, nil, fmt.Errorf("could not limit: %w", err)
220+
return false, 0, nil, fmt.Errorf("could not limit: %w", err)
197221
}
198222
if !allowed {
199-
return false, nil, nil // limit exceeded
223+
return false, retryAfter, nil, nil // limit exceeded
200224
}
201225
r := &unsupportedReturn{}
202-
return true, r.Return, nil
226+
return true, 0, r.Return, nil
203227
}
204228

205229
func (l *Limiter) getTimer(key, algo string, rate, window int64) func() {

0 commit comments

Comments
 (0)