Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package identities

import (
"context"
"fmt"
"os"
"testing"
"time"

"github.com/stretchr/testify/require"
unkey "github.com/unkeyed/unkey-go"
"github.com/unkeyed/unkey-go/models/components"
"github.com/unkeyed/unkey-go/models/operations"
attack "github.com/unkeyed/unkey/apps/agent/pkg/testutil"
"github.com/unkeyed/unkey/apps/agent/pkg/uid"
"github.com/unkeyed/unkey/apps/agent/pkg/util"
)

func TestIdentitiesRatelimitAccuracy(t *testing.T) {
// Step 1 --------------------------------------------------------------------
// Setup the sdk, create an API and an identity
// ---------------------------------------------------------------------------

ctx := context.Background()
rootKey := os.Getenv("INTEGRATION_TEST_ROOT_KEY")
require.NotEmpty(t, rootKey, "INTEGRATION_TEST_ROOT_KEY must be set")
baseURL := os.Getenv("UNKEY_BASE_URL")
require.NotEmpty(t, baseURL, "UNKEY_BASE_URL must be set")

sdk := unkey.New(
unkey.WithServerURL(baseURL),
unkey.WithSecurity(rootKey),
)

for _, nKeys := range []int{1} { //, 3, 10, 1000} {
t.Run(fmt.Sprintf("with %d keys", nKeys), func(t *testing.T) {

for _, tc := range []struct {
rate attack.Rate
testDuration time.Duration
}{
{
rate: attack.Rate{Freq: 20, Per: time.Second},
testDuration: 1 * time.Minute,
},
{
rate: attack.Rate{Freq: 100, Per: time.Second},
testDuration: 5 * time.Minute,
},
} {
t.Run(fmt.Sprintf("[%s] over %s", tc.rate.String(), tc.testDuration), func(t *testing.T) {
api, err := sdk.Apis.CreateAPI(ctx, operations.CreateAPIRequestBody{
Name: uid.New("testapi"),
})
require.NoError(t, err)

externalId := uid.New("testuser")

_, err = sdk.Identities.CreateIdentity(ctx, operations.CreateIdentityRequestBody{
ExternalID: externalId,
Meta: map[string]any{
"email": "test@test.com",
},
})
require.NoError(t, err)

// Step 2 --------------------------------------------------------------------
// Update the identity with ratelimits
// ---------------------------------------------------------------------------

inferenceLimit := operations.UpdateIdentityRatelimits{
Name: "inferenceLimit",
Limit: 100,
Duration: time.Minute.Milliseconds(),
}

_, err = sdk.Identities.UpdateIdentity(ctx, operations.UpdateIdentityRequestBody{
ExternalID: unkey.String(externalId),
Ratelimits: []operations.UpdateIdentityRatelimits{inferenceLimit},
})
require.NoError(t, err)

// Step 4 --------------------------------------------------------------------
// Create keys that share the same identity and therefore the same ratelimits
// ---------------------------------------------------------------------------

keys := make([]operations.CreateKeyResponseBody, nKeys)
for i := 0; i < len(keys); i++ {
key, err := sdk.Keys.CreateKey(ctx, operations.CreateKeyRequestBody{
APIID: api.Object.APIID,
ExternalID: unkey.String(externalId),
Environment: unkey.String("integration_test"),
})
require.NoError(t, err)
keys[i] = *key.Object
}

// Step 5 --------------------------------------------------------------------
// Test ratelimits
// ---------------------------------------------------------------------------

total := 0
passed := 0

results := attack.Attack(t, tc.rate, tc.testDuration, func() bool {

// Each request uses one of the keys randomly
key := util.RandomElement(keys).Key

res, err := sdk.Keys.VerifyKey(context.Background(), components.V1KeysVerifyKeyRequest{
APIID: unkey.String(api.Object.APIID),
Key: key,
Ratelimits: []components.Ratelimits{
{Name: inferenceLimit.Name},
},
})
require.NoError(t, err)

return res.V1KeysVerifyKeyResponse.Valid

})

for valid := range results {
total++
if valid {
passed++
}

}

// Step 6 --------------------------------------------------------------------
// Assert ratelimits worked
// ---------------------------------------------------------------------------

exactLimit := int(inferenceLimit.Limit) * int(tc.testDuration/(time.Duration(inferenceLimit.Duration)*time.Millisecond))
upperLimit := int(1.2 * float64(exactLimit))
lowerLimit := exactLimit
if total < lowerLimit {
lowerLimit = total
}
t.Logf("Total: %d, Passed: %d, lowerLimit: %d, exactLimit: %d, upperLimit: %d", total, passed, lowerLimit, exactLimit, upperLimit)

// check requests::api is not exceeded
require.GreaterOrEqual(t, passed, lowerLimit)
require.LessOrEqual(t, passed, upperLimit)
})
}
})
}
}
12 changes: 4 additions & 8 deletions apps/agent/integration/identities/token_ratelimits_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,10 @@ func TestClusterRatelimitAccuracy(t *testing.T) {
baseURL := os.Getenv("UNKEY_BASE_URL")
require.NotEmpty(t, baseURL, "UNKEY_BASE_URL must be set")

options := []unkey.SDKOption{
sdk := unkey.New(
unkey.WithServerURL(baseURL),
unkey.WithSecurity(rootKey),
}

if baseURL != "" {
options = append(options, unkey.WithServerURL(baseURL))
}
sdk := unkey.New(options...)
)

api, err := sdk.Apis.CreateAPI(ctx, operations.CreateAPIRequestBody{
Name: uid.New("testapi"),
Expand All @@ -47,7 +43,7 @@ func TestClusterRatelimitAccuracy(t *testing.T) {
_, err = sdk.Identities.CreateIdentity(ctx, operations.CreateIdentityRequestBody{
ExternalID: externalId,
Meta: map[string]any{
"email": "andreas@unkey.dev",
"email": "test@test.com",
},
})
require.NoError(t, err)
Expand Down
125 changes: 125 additions & 0 deletions apps/agent/integration/keys/ratelimits_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
package keys_test

import (
"context"
"fmt"
"os"
"testing"
"time"

"github.com/stretchr/testify/require"
unkey "github.com/unkeyed/unkey-go"
"github.com/unkeyed/unkey-go/models/components"
"github.com/unkeyed/unkey-go/models/operations"
attack "github.com/unkeyed/unkey/apps/agent/pkg/testutil"
"github.com/unkeyed/unkey/apps/agent/pkg/uid"
"github.com/unkeyed/unkey/apps/agent/pkg/util"
)

func TestDefaultRatelimitAccuracy(t *testing.T) {
// Step 1 --------------------------------------------------------------------
// Setup the sdk, create an API and a key
// ---------------------------------------------------------------------------

ctx := context.Background()
rootKey := os.Getenv("INTEGRATION_TEST_ROOT_KEY")
require.NotEmpty(t, rootKey, "INTEGRATION_TEST_ROOT_KEY must be set")
baseURL := os.Getenv("UNKEY_BASE_URL")
require.NotEmpty(t, baseURL, "UNKEY_BASE_URL must be set")

options := []unkey.SDKOption{
unkey.WithSecurity(rootKey),
}

if baseURL != "" {
options = append(options, unkey.WithServerURL(baseURL))
}
sdk := unkey.New(options...)

for _, tc := range []struct {
rate attack.Rate
testDuration time.Duration
}{
{
rate: attack.Rate{Freq: 20, Per: time.Second},
testDuration: 1 * time.Minute,
},
{
rate: attack.Rate{Freq: 100, Per: time.Second},
testDuration: 5 * time.Minute,
},
} {
t.Run(fmt.Sprintf("[%s] over %s", tc.rate.String(), tc.testDuration), func(t *testing.T) {
api, err := sdk.Apis.CreateAPI(ctx, operations.CreateAPIRequestBody{
Name: uid.New("testapi"),
})
require.NoError(t, err)

// Step 2 --------------------------------------------------------------------
// Update the identity with ratelimits
// ---------------------------------------------------------------------------

// Step 3 --------------------------------------------------------------------
// Create keys that share the same identity and therefore the same ratelimits
// ---------------------------------------------------------------------------

ratelimit := operations.Ratelimit{
Limit: 100,
Duration: util.Pointer(time.Minute.Milliseconds()),
}

key, err := sdk.Keys.CreateKey(ctx, operations.CreateKeyRequestBody{
APIID: api.Object.APIID,
Ratelimit: &ratelimit,
})
require.NoError(t, err)

// Step 5 --------------------------------------------------------------------
// Test ratelimits
// ---------------------------------------------------------------------------

total := 0
passed := 0

results := attack.Attack(t, tc.rate, tc.testDuration, func() bool {

res, err := sdk.Keys.VerifyKey(context.Background(), components.V1KeysVerifyKeyRequest{
APIID: unkey.String(api.Object.APIID),
Key: key.Object.Key,
Ratelimits: []components.Ratelimits{
{Name: "default"},
},
})
require.NoError(t, err)

return res.V1KeysVerifyKeyResponse.Valid

})

for valid := range results {
total++
if valid {
passed++
}

}

// Step 6 --------------------------------------------------------------------
// Assert ratelimits worked
// ---------------------------------------------------------------------------

exactLimit := int(ratelimit.Limit) * int(tc.testDuration/(time.Duration(*ratelimit.Duration)*time.Millisecond))
upperLimit := int(1.2 * float64(exactLimit))
lowerLimit := exactLimit
if total < lowerLimit {
lowerLimit = total
}
t.Logf("Total: %d, Passed: %d, lowerLimit: %d, exactLimit: %d, upperLimit: %d", total, passed, lowerLimit, exactLimit, upperLimit)

// check requests::api is not exceeded
require.GreaterOrEqual(t, passed, lowerLimit)
require.LessOrEqual(t, passed, upperLimit)
})

}
}
2 changes: 1 addition & 1 deletion apps/agent/pkg/circuitbreaker/lib.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ func (cb *CB[Res]) preflight(ctx context.Context) error {
return ErrTripped
}

cb.logger.Info().Str("state", string(cb.state)).Int("requests", cb.requests).Int("maxRequests", cb.config.maxRequests).Msg("circuit breaker state")
cb.logger.Debug().Str("state", string(cb.state)).Int("requests", cb.requests).Int("maxRequests", cb.config.maxRequests).Msg("circuit breaker state")
if cb.state == HalfOpen && cb.requests >= cb.config.maxRequests {
return ErrTooManyRequests
}
Expand Down
1 change: 0 additions & 1 deletion apps/agent/pkg/clickhouse/flush.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ func flush[T any](ctx context.Context, conn ch.Conn, table string, rows []T) err
return fault.Wrap(err, fmsg.With("preparing batch failed"))
}
for _, row := range rows {
fmt.Printf("row: %+v\n", row)
err = batch.AppendStruct(&row)
if err != nil {
return fault.Wrap(err, fmsg.With("appending struct to batch failed"))
Expand Down
69 changes: 69 additions & 0 deletions apps/agent/pkg/testutil/attack.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package attack

import (
"fmt"
"sync"
"testing"
"time"
)

type Rate struct {
Freq int
Per time.Duration
}

func (r Rate) String() string {
return fmt.Sprintf("%d per %s", r.Freq, r.Per)
}

// Attack executes the given function at the given rate for the given duration
// and returns a channel on which the results are sent.
//
// The caller must process the results as they arrive on the channel to avoid
// blocking the worker goroutines.
func Attack[Response any](t *testing.T, rate Rate, duration time.Duration, fn func() Response) <-chan Response {
t.Log("attacking")
wg := sync.WaitGroup{}
workers := 256

ticks := make(chan struct{})
responses := make(chan Response)

totalRequests := rate.Freq * int(duration/rate.Per)
dt := rate.Per / time.Duration(rate.Freq)

wg.Add(totalRequests)

go func() {
for i := 0; i < totalRequests; i++ {
ticks <- struct{}{}
time.Sleep(dt)
}
}()

for i := 0; i < workers; i++ {
go func() {
for range ticks {
responses <- fn()
wg.Done()

}
}()
}

go func() {
wg.Wait()
t.Log("attack done, waiting for responses to be processed")

close(ticks)
pending := len(responses)
for pending > 0 {
t.Logf("waiting for responses to be processed: %d", pending)
time.Sleep(100 * time.Millisecond)
}
close(responses)

}()

return responses
}
Loading