Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go/Taskfile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ tasks:
INTEGRATION_TEST: true
SIMULATION_TEST: false
cmds:
- go test -failfast -timeout=60m -shuffle=on ./...
- go test -failfast -timeout=60m -shuffle=on -v ./...

test-unit:
cmds:
Expand Down
7 changes: 4 additions & 3 deletions go/apps/api/cancel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ func TestContextCancellation(t *testing.T) {
// Create a containers instance for database
containers := containers.New(t)
dbDsn, _ := containers.RunMySQL()
_, redisUrl, _ := containers.RunRedis()
// Get free ports for the node
portAllocator := port.New()
httpPort := portAllocator.Get()
Expand All @@ -33,10 +34,10 @@ func TestContextCancellation(t *testing.T) {
Image: "test",
HttpPort: httpPort,
Region: "test-region",
Clock: nil, // Will use real clock
ClusterEnabled: false, // Disable clustering for simpler test
ClusterInstanceID: uid.New(uid.InstancePrefix),
Clock: nil, // Will use real clock
InstanceID: uid.New(uid.InstancePrefix),
LogsColor: false,
RedisUrl: redisUrl,
ClickhouseURL: "",
DatabasePrimary: dbDsn,
DatabaseReadonlyReplica: "",
Expand Down
52 changes: 8 additions & 44 deletions go/apps/api/config.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package api

import (
"github.com/unkeyed/unkey/go/pkg/assert"
"github.com/unkeyed/unkey/go/pkg/clock"
)

type Config struct {

// InstanceID is the unique identifier for this instance of the API server
InstanceID string
// Platform identifies the cloud platform where the node is running (e.g., aws, gcp, hetzner)
Platform string

Expand All @@ -18,38 +20,11 @@ type Config struct {
// Region identifies the geographic region where this node is deployed
Region string

// --- Cluster configuration ---

ClusterEnabled bool

// ClusterInstanceID is the unique identifier for this instance within the cluster
ClusterInstanceID string

// --- Advertise Address configuration ---

// ClusterAdvertiseAddrStatic is a static IP address or hostname for node discovery
ClusterAdvertiseAddrStatic string

// ClusterAdvertiseAddrAwsEcsMetadata enables automatic address discovery using AWS ECS container metadata
ClusterAdvertiseAddrAwsEcsMetadata bool

// ClusterRpcPort is the port used for internal RPC communication between nodes (default: 7071)
ClusterRpcPort int

// ClusterGossipPort is the port used for cluster membership and failure detection (default: 7072)
ClusterGossipPort int

// --- Discovery configuration ---

// ClusterDiscoveryStaticAddrs lists seed node addresses for static cluster configuration
ClusterDiscoveryStaticAddrs []string

// ClusterDiscoveryRedisURL provides a Redis connection string for dynamic cluster discovery
ClusterDiscoveryRedisURL string

// ClusterDiscoveryAwsEcs uses the aws ecs API to find peers
ClusterDiscoveryAwsEcs bool
// RedisUrl is the Redis database connection string
RedisUrl string

// TestMode enables test-only behavior in the API server
TestMode bool
// --- Logs configuration ---

// LogsColor enables ANSI color codes in log output
Expand Down Expand Up @@ -80,17 +55,6 @@ type Config struct {

// Validate reports whether the configuration is usable.
//
// As rendered here, checks only run when ClusterEnabled is set: the cluster
// instance ID must be non-empty, the RPC and gossip ports must be greater
// than zero, and an advertise address must be supplied either statically or
// via AWS ECS metadata. Any error produced by assert.All is returned as-is;
// in every other case Validate returns nil.
//
// NOTE(review): this text comes from a rendered diff, and the cluster block
// looks like the removed side of the change (it conflicts with the
// "nothing to validate yet" note below) — confirm against the real file.
func (c Config) Validate() error {

if c.ClusterEnabled {
err := assert.All(
assert.NotEmpty(c.ClusterInstanceID, "instance id must not be empty"),
assert.Greater(c.ClusterRpcPort, 0),
assert.Greater(c.ClusterGossipPort, 0),
assert.True(c.ClusterAdvertiseAddrStatic != "" || c.ClusterAdvertiseAddrAwsEcsMetadata),
)
if err != nil {
return err
}
}

// nothing to validate yet
return nil
}
215 changes: 128 additions & 87 deletions go/apps/api/integration/multi_node_ratelimiting/accuracy_test.go
Original file line number Diff line number Diff line change
@@ -1,141 +1,182 @@
package multi_node_ratelimiting_test
package multi_node_ratelimiting

import (
"context"
"fmt"
"math"
"net/http"
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/require"
"github.com/unkeyed/unkey/go/apps/api/integration"
handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_ratelimit_limit"
"github.com/unkeyed/unkey/go/pkg/attack"
"github.com/unkeyed/unkey/go/pkg/clock"
"github.com/unkeyed/unkey/go/pkg/db"
"github.com/unkeyed/unkey/go/pkg/testutil"
"github.com/unkeyed/unkey/go/pkg/uid"
)

func TestAccuracy(t *testing.T) {
func TestRateLimitAccuracy(t *testing.T) {
testutil.SkipUnlessIntegration(t)

// How many nodes to simulate
nodes := []int{1, 3, 27}
nodeCounts := []int{
1, 3, 9,
}

limits := []int64{5, 100}
// Define test matrices for each dimension
limits := []int64{
5,
100,
10000,
}

durations := []time.Duration{
1 * time.Second,
5 * time.Second,
1 * time.Minute,
1 * time.Hour,
24 * time.Hour,
}

// Define load patterns as multipliers of the limit
loadFactors := []float64{0.9, 10.0}
loadFactors := []float64{
0.9, // Slightly below limit
1.0, // At limit
10.0, // Well above limit

testDurations := []time.Duration{
time.Minute,
}

for _, nodeCount := range nodes {
for _, testDuration := range testDurations {
for _, limit := range limits {
for _, duration := range durations {
for _, loadFactor := range loadFactors {
// Number of windows to test (determines test duration)
windowCounts := []int{
100,
}

t.Run(fmt.Sprintf("nodes=%d_test=%s_limit=%d_duration=%s_loadFactor=%f", nodeCount, testDuration, limit, duration, loadFactor), func(t *testing.T) {
for _, nodes := range nodeCounts {
t.Run(fmt.Sprintf("nodes_%d", nodes), func(t *testing.T) {
h := integration.New(t, integration.Config{

NumNodes: nodes,
})

for _, windows := range windowCounts {
for _, limit := range limits {
for _, duration := range durations {
for _, loadFactor := range loadFactors {
t.Run(fmt.Sprintf("windows=%d_limit=%d_duration=%d_load=%.1fx", windows, limit, duration, loadFactor), func(t *testing.T) {

ctx := context.Background()

// Create a namespace
namespaceID := uid.New(uid.RatelimitNamespacePrefix)
namespaceName := uid.New("test")
err := db.Query.InsertRatelimitNamespace(ctx, h.DB.RW(), db.InsertRatelimitNamespaceParams{
ID: namespaceID,
WorkspaceID: h.Resources().UserWorkspace.ID,
Name: namespaceName,
CreatedAt: time.Now().UnixMilli(),
})
require.NoError(t, err)

rootKey := h.Seed.CreateRootKey(ctx, h.Seed.Resources.UserWorkspace.ID, fmt.Sprintf("ratelimit.%s.limit", namespaceID))

headers := http.Header{
"Content-Type": {"application/json"},
"Authorization": {fmt.Sprintf("Bearer %s", rootKey)},
}

testutil.SkipUnlessIntegration(t)
// Create a unique identifier for this test
identifier := uid.New("test")

ctx := context.Background()
// Calculate test parameters
// RPS based on loadFactor and limit/duration
rps := int(math.Ceil(float64(limit) * loadFactor * (1000.0 / float64(duration.Milliseconds()))))

// Setup a cluster with the specified number of nodes
h := integration.New(t, integration.Config{
NumNodes: nodeCount,
})
// Must have at least 1 RPS
if rps < 1 {
rps = 1
}

// Create a namespace
namespaceID := uid.New(uid.RatelimitNamespacePrefix)
namespaceName := uid.New("test")
err := db.Query.InsertRatelimitNamespace(ctx, h.DB.RW(), db.InsertRatelimitNamespaceParams{
ID: namespaceID,
WorkspaceID: h.Resources().UserWorkspace.ID,
Name: namespaceName,
CreatedAt: time.Now().UnixMilli(),
})
require.NoError(t, err)
// Total seconds needed to cover the requested number of windows
seconds := int(math.Ceil(float64(windows) * float64(duration.Milliseconds()) / 1000.0))

// Create auth for the test
rootKey := h.Seed.CreateRootKey(ctx, h.Resources().UserWorkspace.ID, fmt.Sprintf("ratelimit.%s.limit", namespaceID))
headers := http.Header{
"Content-Type": {"application/json"},
"Authorization": {fmt.Sprintf("Bearer %s", rootKey)},
}
// Request that will be sent repeatedly
req := handler.Request{
Namespace: namespaceName,
Identifier: identifier,
Limit: limit,
Duration: duration.Milliseconds(),
}

identifier := uid.New("test")
// Calculate number of windows and expected limits
totalRequests := rps * seconds
numWindows := float64(seconds*1000) / float64(duration.Milliseconds())

req := handler.Request{
Namespace: namespaceName,
Limit: limit,
Duration: duration.Milliseconds(),
Identifier: identifier,
}
// Calculate theoretical maximum allowed requests
maxAllowed := numWindows * float64(limit)
maxAllowed = math.Min(maxAllowed, float64(totalRequests))

// Calculate test parameters
// RPS based on loadFactor and limit/duration
rps := int(math.Ceil(float64(limit) * loadFactor * (1000.0 / float64(duration.Milliseconds()))))
// Calculate limits with some tolerance
upperLimit := int(maxAllowed * 1.05)
lowerLimit := int(maxAllowed * 0.95)

// Total seconds needed to cover windowCount windows
// Special case for scenarios at or below the limit
rpsPerWindow := float64(rps) * (float64(duration.Milliseconds()) / 1000.0)
if rpsPerWindow <= float64(limit) {
// When at or below the limit, we expect all or nearly all to succeed
lowerLimit = int(float64(totalRequests) * 0.95)
upperLimit = totalRequests
}

total := 0
passed := 0
// Cap at total requests
upperLimit = min(upperLimit, totalRequests)

lb := integration.NewLoadbalancer(h)
realStart := time.Now()
// Run load test
clk := clock.NewTestClock(realStart)
simulatedStart := clk.Now()
successCount := 0

errors := atomic.Int64{}
// Calculate interval between requests to achieve desired RPS
interval := time.Second / time.Duration(rps)

results := attack.Attack[integration.TestResponse[handler.Response]](t, attack.Rate{Freq: rps, Per: time.Second}, testDuration, func() integration.TestResponse[handler.Response] {
res, err := integration.CallRandomNode[handler.Request, handler.Response](lb, "POST", "/v2/ratelimit.limit", headers, req)
lb := integration.NewLoadbalancer(h)

if err != nil {
errors.Add(1)
}
return res
})
t.Logf("sending %d requests", totalRequests)
for i := 0; i < totalRequests; i++ {
// Simulate request timing to achieve target RPS
clk.Tick(interval)
//time.Sleep(interval)

require.Less(t, errors.Load(), int64(5))
headers.Set("X-Test-Time", fmt.Sprintf("%d", clk.Now().UnixMilli()))
res, err := integration.CallRandomNode[handler.Request, handler.Response](lb, "POST", "/v2/ratelimit.limit", headers, req)
require.NoError(t, err)
require.NoError(t, err)
require.Equal(t, 200, res.Status, "expected 200 status")

for res := range results {
require.Equal(t, http.StatusOK, res.Status, "expected 200 status, but got:%s", res.RawBody)
total++
if res.Body.Data.Success {
passed++
if res.Body.Data.Success {
successCount++
}
}

}

windows := math.Ceil(float64(testDuration) / float64(duration))
// Calculate theoretical maximum allowed requests
maxAllowed := math.Min(windows*float64(limit), float64(total))

// Calculate limits with some tolerance
upperLimit := int(maxAllowed * 1.2)
lowerLimit := int(math.Min(maxAllowed*0.95, float64(total)))
simulatedDuration := clk.Now().Sub(simulatedStart)
realDuration := time.Since(realStart)

t.Logf("windows: %d, total: %d, passed: %d, acceptable: [%d - %d]", int(windows), total, passed, lowerLimit, upperLimit)
// Verify results
require.GreaterOrEqual(t, passed, lowerLimit,
"Passed count should be >= lower limit")
require.LessOrEqual(t, passed, upperLimit,
"Passed count should be <= upper limit")
t.Logf("Load test simulated %s in %s (%.2f%%)",
simulatedDuration, realDuration, float64(simulatedDuration)/float64(realDuration)*100.0)

t.Logf("balance: %+v", lb.GetMetrics())

})
lbMetrics := lb.GetMetrics()
require.Equal(t, nodes, len(lbMetrics), "all nodes should have received traffic")
// Verify results
require.GreaterOrEqual(t, successCount, lowerLimit,
"Success count should be >= lower limit")
require.LessOrEqual(t, successCount, upperLimit,
"Success count should be <= upper limit")
})
}
}
}
}
}
}
})

}
}
1 change: 1 addition & 0 deletions go/apps/api/routes/register.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ func Register(srv *zen.Server, svc *Services) {
Permissions: svc.Permissions,
RatelimitNamespaceByNameCache: svc.Caches.RatelimitNamespaceByName,
RatelimitOverrideMatchesCache: svc.Caches.RatelimitOverridesMatch,
TestMode: srv.Flags().TestMode,
}),
)
// v2/ratelimit.setOverride
Expand Down
1 change: 0 additions & 1 deletion go/apps/api/routes/v2_ratelimit_limit/accuracy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ func TestRateLimitAccuracy(t *testing.T) {
t.Run(fmt.Sprintf("duration_%dms", duration), func(t *testing.T) {
for _, loadFactor := range loadFactors {
t.Run(fmt.Sprintf("load_%.1fx", loadFactor), func(t *testing.T) {
t.Parallel()
h := testutil.NewHarness(t)

route := handler.New(handler.Services{
Expand Down
Loading
Loading