Skip to content

Commit f59588a

Browse files
authored
feat: send bucket update when rate limit applied (#1277)
1 parent fdf03de commit f59588a

File tree

2 files changed

+62
-9
lines changed

2 files changed

+62
-9
lines changed

waku/v2/api/publish/rln_rate_limiting.go

+21-6
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ import (
1111

1212
var ErrRateLimited = errors.New("rate limit exceeded")
1313

14-
const RlnLimiterCapacity = 100
15-
const RlnLimiterRefillInterval = 10 * time.Minute
14+
const DefaultRlnLimiterCapacity = 600
15+
const DefaultRlnLimiterRefillInterval = 10 * time.Minute
1616

1717
// RlnRateLimiter is used to rate limit the outgoing messages,
1818
// The capacity and refillInterval comes from RLN contract configuration.
@@ -22,15 +22,23 @@ type RlnRateLimiter struct {
2222
tokens int
2323
refillInterval time.Duration
2424
lastRefill time.Time
25+
updateCh chan RlnRateLimitState
26+
}
27+
28+
// RlnRateLimitState includes the information that need to be persisted in database.
29+
type RlnRateLimitState struct {
30+
RemainingTokens int
31+
LastRefill time.Time
2532
}
2633

2734
// NewRlnPublishRateLimiter creates a new rate limiter, starts with a full capacity bucket.
28-
func NewRlnRateLimiter(capacity int, refillInterval time.Duration) *RlnRateLimiter {
35+
func NewRlnRateLimiter(capacity int, refillInterval time.Duration, state RlnRateLimitState, updateCh chan RlnRateLimitState) *RlnRateLimiter {
2936
return &RlnRateLimiter{
3037
capacity: capacity,
31-
tokens: capacity, // Start with a full bucket
38+
tokens: state.RemainingTokens,
3239
refillInterval: refillInterval,
33-
lastRefill: time.Now(),
40+
lastRefill: state.LastRefill,
41+
updateCh: updateCh,
3442
}
3543
}
3644

@@ -42,19 +50,26 @@ func (rl *RlnRateLimiter) Allow() bool {
4250
// Refill tokens if the refill interval has passed
4351
now := time.Now()
4452
if now.Sub(rl.lastRefill) >= rl.refillInterval {
45-
rl.tokens = rl.capacity // Refill the bucket
53+
rl.tokens = rl.capacity
4654
rl.lastRefill = now
55+
rl.sendUpdate()
4756
}
4857

4958
// Check if there are tokens available
5059
if rl.tokens > 0 {
5160
rl.tokens--
61+
rl.sendUpdate()
5262
return true
5363
}
5464

5565
return false
5666
}
5767

68+
// sendUpdate sends the latest token state to the update channel.
69+
func (rl *RlnRateLimiter) sendUpdate() {
70+
rl.updateCh <- RlnRateLimitState{RemainingTokens: rl.tokens, LastRefill: rl.lastRefill}
71+
}
72+
5873
func (rl *RlnRateLimiter) Check(ctx context.Context, logger *zap.Logger) error {
5974
if rl.Allow() {
6075
return nil

waku/v2/api/publish/rln_rate_limiting_test.go

+41-3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package publish
22

33
import (
44
"context"
5+
"sync"
56
"testing"
67
"time"
78

@@ -10,17 +11,54 @@ import (
1011
)
1112

1213
func TestRlnRateLimit(t *testing.T) {
13-
r := NewRlnRateLimiter(3, 5*time.Second)
14+
updateCh := make(chan RlnRateLimitState, 10)
15+
refillTime := time.Now()
16+
capacity := 3
17+
state := RlnRateLimitState{
18+
RemainingTokens: capacity,
19+
LastRefill: refillTime,
20+
}
21+
r := NewRlnRateLimiter(capacity, 5*time.Second, state, updateCh)
1422
l := utils.Logger()
1523

16-
for i := 0; i < 3; i++ {
24+
ctx, cancel := context.WithCancel(context.Background())
25+
defer cancel()
26+
27+
sleepDuration := 6 * time.Second
28+
var mu sync.Mutex
29+
go func(ctx context.Context, ch chan RlnRateLimitState) {
30+
usedToken := 0
31+
for {
32+
select {
33+
case update := <-ch:
34+
mu.Lock()
35+
if update.LastRefill != refillTime {
36+
usedToken = 0
37+
require.WithinDuration(t, refillTime.Add(sleepDuration), update.LastRefill, time.Second, "Last refill timestamp is incorrect")
38+
require.Equal(t, update.RemainingTokens, capacity)
39+
continue
40+
}
41+
usedToken++
42+
require.Equal(t, update.RemainingTokens, capacity-usedToken)
43+
mu.Unlock()
44+
case <-ctx.Done():
45+
return
46+
}
47+
}
48+
}(ctx, updateCh)
49+
50+
for i := 0; i < capacity; i++ {
1751
require.NoError(t, r.Check(context.Background(), l))
1852
}
1953
require.ErrorIs(t, r.Check(context.Background(), l), ErrRateLimited)
2054

21-
time.Sleep(6 * time.Second)
55+
time.Sleep(sleepDuration)
56+
2257
for i := 0; i < 3; i++ {
2358
require.NoError(t, r.Check(context.Background(), l))
2459
}
2560
require.ErrorIs(t, r.Check(context.Background(), l), ErrRateLimited)
61+
62+
// wait for goroutine to finish
63+
time.Sleep(time.Second)
2664
}

0 commit comments

Comments
 (0)