diff --git a/go/Taskfile.yml b/go/Taskfile.yml index b562d06863..92e87a819f 100644 --- a/go/Taskfile.yml +++ b/go/Taskfile.yml @@ -20,7 +20,7 @@ tasks: INTEGRATION_TEST: true SIMULATION_TEST: false cmds: - - go test -failfast -timeout=60m -shuffle=on ./... + - go test -failfast -timeout=60m -shuffle=on -v ./... test-unit: cmds: diff --git a/go/apps/api/cancel_test.go b/go/apps/api/cancel_test.go index ae999db9b1..77c12fe105 100644 --- a/go/apps/api/cancel_test.go +++ b/go/apps/api/cancel_test.go @@ -20,6 +20,7 @@ func TestContextCancellation(t *testing.T) { // Create a containers instance for database containers := containers.New(t) dbDsn, _ := containers.RunMySQL() + _, redisUrl, _ := containers.RunRedis() // Get free ports for the node portAllocator := port.New() httpPort := portAllocator.Get() @@ -33,10 +34,10 @@ func TestContextCancellation(t *testing.T) { Image: "test", HttpPort: httpPort, Region: "test-region", - Clock: nil, // Will use real clock - ClusterEnabled: false, // Disable clustering for simpler test - ClusterInstanceID: uid.New(uid.InstancePrefix), + Clock: nil, // Will use real clock + InstanceID: uid.New(uid.InstancePrefix), LogsColor: false, + RedisUrl: redisUrl, ClickhouseURL: "", DatabasePrimary: dbDsn, DatabaseReadonlyReplica: "", diff --git a/go/apps/api/config.go b/go/apps/api/config.go index 8de76928f0..83eb017386 100644 --- a/go/apps/api/config.go +++ b/go/apps/api/config.go @@ -1,11 +1,13 @@ package api import ( - "github.com/unkeyed/unkey/go/pkg/assert" "github.com/unkeyed/unkey/go/pkg/clock" ) type Config struct { + + // InstanceID is the unique identifier for this instance of the API server + InstanceID string // Platform identifies the cloud platform where the node is running (e.g., aws, gcp, hetzner) Platform string @@ -18,38 +20,11 @@ type Config struct { // Region identifies the geographic region where this node is deployed Region string - // --- Cluster configuration --- - - ClusterEnabled bool - - // ClusterInstanceID is the unique identifier for this instance within the cluster - ClusterInstanceID string - - // --- Advertise Address configuration --- - - // ClusterAdvertiseAddrStatic is a static IP address or hostname for node discovery - ClusterAdvertiseAddrStatic string - - // ClusterAdvertiseAddrAwsEcsMetadata enables automatic address discovery using AWS ECS container metadata - ClusterAdvertiseAddrAwsEcsMetadata bool - - // ClusterRpcPort is the port used for internal RPC communication between nodes (default: 7071) - ClusterRpcPort int - - // ClusterGossipPort is the port used for cluster membership and failure detection (default: 7072) - ClusterGossipPort int - - // --- Discovery configuration --- - - // ClusterDiscoveryStaticAddrs lists seed node addresses for static cluster configuration - ClusterDiscoveryStaticAddrs []string - - // ClusterDiscoveryRedisURL provides a Redis connection string for dynamic cluster discovery - ClusterDiscoveryRedisURL string - - // ClusterDiscoveryAwsEcs uses the aws ecs API to find peers - ClusterDiscoveryAwsEcs bool + // RedisUrl is the Redis database connection string + RedisUrl string + // Enable TestMode + TestMode bool // --- Logs configuration --- // LogsColor enables ANSI color codes in log output @@ -80,17 +55,6 @@ type Config struct { func (c Config) Validate() error { - if c.ClusterEnabled { - err := assert.All( - assert.NotEmpty(c.ClusterInstanceID, "instance id must not be empty"), - assert.Greater(c.ClusterRpcPort, 0), - assert.Greater(c.ClusterGossipPort, 0), - assert.True(c.ClusterAdvertiseAddrStatic != "" || c.ClusterAdvertiseAddrAwsEcsMetadata), - ) - if err != nil { - return err - } - } - + // nothing to validate yet return nil } diff --git a/go/apps/api/integration/multi_node_ratelimiting/accuracy_test.go b/go/apps/api/integration/multi_node_ratelimiting/accuracy_test.go index e158d179f6..d96f7042ad 100644 --- a/go/apps/api/integration/multi_node_ratelimiting/accuracy_test.go +++ b/go/apps/api/integration/multi_node_ratelimiting/accuracy_test.go @@ -1,141 +1,182 @@ -package multi_node_ratelimiting_test +package multi_node_ratelimiting import ( "context" "fmt" "math" "net/http" - "sync/atomic" "testing" "time" "github.com/stretchr/testify/require" "github.com/unkeyed/unkey/go/apps/api/integration" handler "github.com/unkeyed/unkey/go/apps/api/routes/v2_ratelimit_limit" - "github.com/unkeyed/unkey/go/pkg/attack" + "github.com/unkeyed/unkey/go/pkg/clock" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/testutil" "github.com/unkeyed/unkey/go/pkg/uid" ) -func TestAccuracy(t *testing.T) { +func TestRateLimitAccuracy(t *testing.T) { + testutil.SkipUnlessIntegration(t) - // How many nodes to simulate - nodes := []int{1, 3, 27} + nodeCounts := []int{ + 1, 3, 9, + } - limits := []int64{5, 100} + // Define test matrices for each dimension + limits := []int64{ + 5, + 100, + 10000, + } durations := []time.Duration{ 1 * time.Second, - 5 * time.Second, + 1 * time.Minute, + 1 * time.Hour, + 24 * time.Hour, } // Define load patterns as multipliers of the limit - loadFactors := []float64{0.9, 10.0} + loadFactors := []float64{ + 0.9, // Slightly below limit + 1.0, // At limit + 10.0, // Well above limit - testDurations := []time.Duration{ - time.Minute, } - for _, nodeCount := range nodes { - for _, testDuration := range testDurations { - for _, limit := range limits { - for _, duration := range durations { - for _, loadFactor := range loadFactors { + // Number of windows to test (determines test duration) + windowCounts := []int{ + 100, + } - t.Run(fmt.Sprintf("nodes=%d_test=%s_limit=%d_duration=%s_loadFactor=%f", nodeCount, testDuration, limit, duration, loadFactor), func(t *testing.T) { + for _, nodes := range nodeCounts { + t.Run(fmt.Sprintf("nodes_%d", nodes), func(t *testing.T) { + h := integration.New(t, integration.Config{ + + NumNodes: nodes, + }) + + for _, windows := range windowCounts { + for _, limit := range limits { + for _, duration := range durations { + for _, loadFactor := range loadFactors { + t.Run(fmt.Sprintf("windows=%d_limit=%d_duration=%d_load=%.1fx", windows, limit, duration, loadFactor), func(t *testing.T) { + + ctx := context.Background() + + // Create a namespace + namespaceID := uid.New(uid.RatelimitNamespacePrefix) + namespaceName := uid.New("test") + err := db.Query.InsertRatelimitNamespace(ctx, h.DB.RW(), db.InsertRatelimitNamespaceParams{ + ID: namespaceID, + WorkspaceID: h.Resources().UserWorkspace.ID, + Name: namespaceName, + CreatedAt: time.Now().UnixMilli(), + }) + require.NoError(t, err) + + rootKey := h.Seed.CreateRootKey(ctx, h.Seed.Resources.UserWorkspace.ID, fmt.Sprintf("ratelimit.%s.limit", namespaceID)) + + headers := http.Header{ + "Content-Type": {"application/json"}, + "Authorization": {fmt.Sprintf("Bearer %s", rootKey)}, + } - testutil.SkipUnlessIntegration(t) + // Create a unique identifier for this test + identifier := uid.New("test") - ctx := context.Background() + // Calculate test parameters + // RPS based on loadFactor and limit/duration + rps := int(math.Ceil(float64(limit) * loadFactor * (1000.0 / float64(duration.Milliseconds())))) - // Setup a cluster with the specified number of nodes - h := integration.New(t, integration.Config{ - NumNodes: nodeCount, - }) + // Must have at least 1 RPS + if rps < 1 { + rps = 1 + } - // Create a namespace - namespaceID := uid.New(uid.RatelimitNamespacePrefix) - namespaceName := uid.New("test") - err := db.Query.InsertRatelimitNamespace(ctx, h.DB.RW(), db.InsertRatelimitNamespaceParams{ - ID: namespaceID, - WorkspaceID: h.Resources().UserWorkspace.ID, - Name: namespaceName, - CreatedAt: time.Now().UnixMilli(), - }) - require.NoError(t, err) + // Total seconds needed to cover windowCount windows + seconds := int(math.Ceil(float64(windows) * float64(duration.Milliseconds()) / 1000.0)) - // Create auth for the test - rootKey := h.Seed.CreateRootKey(ctx, h.Resources().UserWorkspace.ID, fmt.Sprintf("ratelimit.%s.limit", namespaceID)) - headers := http.Header{ - "Content-Type": {"application/json"}, - "Authorization": {fmt.Sprintf("Bearer %s", rootKey)}, - } + // Request that will be sent repeatedly + req := handler.Request{ + Namespace: namespaceName, + Identifier: identifier, + Limit: limit, + Duration: duration.Milliseconds(), + } - identifier := uid.New("test") + // Calculate number of windows and expected limits + totalRequests := rps * seconds + numWindows := float64(seconds*1000) / float64(duration.Milliseconds()) - req := handler.Request{ - Namespace: namespaceName, - Limit: limit, - Duration: duration.Milliseconds(), - Identifier: identifier, - } + // Calculate theoretical maximum allowed requests + maxAllowed := numWindows * float64(limit) + maxAllowed = math.Min(maxAllowed, float64(totalRequests)) - // Calculate test parameters - // RPS based on loadFactor and limit/duration - rps := int(math.Ceil(float64(limit) * loadFactor * (1000.0 / float64(duration.Milliseconds())))) + // Calculate limits with some tolerance + upperLimit := int(maxAllowed * 1.05) + lowerLimit := int(maxAllowed * 0.95) - // Total seconds needed to cover windowCount windows + // Special case for below-limit scenarios + rpsPerWindow := float64(rps) * (float64(duration.Milliseconds()) / 1000.0) + if rpsPerWindow <= float64(limit) { + // When below limit, we expect all or nearly all to succeed + lowerLimit = int(float64(totalRequests) * 0.95) + upperLimit = totalRequests + } - total := 0 - passed := 0 + // Cap at total requests + upperLimit = min(upperLimit, totalRequests) - lb := integration.NewLoadbalancer(h) + realStart := time.Now() + // Run load test + clk := clock.NewTestClock(realStart) + simulatedStart := clk.Now() + successCount := 0 - errors := atomic.Int64{} + // Calculate interval between requests to achieve desired RPS + interval := time.Second / time.Duration(rps) - results := attack.Attack[integration.TestResponse[handler.Response]](t, attack.Rate{Freq: rps, Per: time.Second}, testDuration, func() integration.TestResponse[handler.Response] { - res, err := integration.CallRandomNode[handler.Request, handler.Response](lb, "POST", "/v2/ratelimit.limit", headers, req) + lb := integration.NewLoadbalancer(h) - if err != nil { - errors.Add(1) - } - return res - }) + t.Logf("sending %d requests", totalRequests) + for i := 0; i < totalRequests; i++ { + // Simulate request timing to achieve target RPS + clk.Tick(interval) + //time.Sleep(interval) - require.Less(t, errors.Load(), int64(5)) + headers.Set("X-Test-Time", fmt.Sprintf("%d", clk.Now().UnixMilli())) + res, err := integration.CallRandomNode[handler.Request, handler.Response](lb, "POST", "/v2/ratelimit.limit", headers, req) + require.NoError(t, err) + require.NoError(t, err) + require.Equal(t, 200, res.Status, "expected 200 status") - for res := range results { - require.Equal(t, http.StatusOK, res.Status, "expected 200 status, but got:%s", res.RawBody) - total++ - if res.Body.Data.Success { - passed++ + if res.Body.Data.Success { + successCount++ + } } - } - - windows := math.Ceil(float64(testDuration) / float64(duration)) - // Calculate theoretical maximum allowed requests - maxAllowed := math.Min(windows*float64(limit), float64(total)) - - // Calculate limits with some tolerance - upperLimit := int(maxAllowed * 1.2) - lowerLimit := int(math.Min(maxAllowed*0.95, float64(total))) + simulatedDuration := clk.Now().Sub(simulatedStart) + realDuration := time.Since(realStart) - t.Logf("windows: %d, total: %d, passed: %d, acceptable: [%d - %d]", int(windows), total, passed, lowerLimit, upperLimit) - // Verify results - require.GreaterOrEqual(t, passed, lowerLimit, - "Passed count should be >= lower limit") - require.LessOrEqual(t, passed, upperLimit, - "Passed count should be <= upper limit") + t.Logf("Load test simulated %s in %s (%.2f%%)", + simulatedDuration, realDuration, float64(simulatedDuration)/float64(realDuration)*100.0) - t.Logf("balance: %+v", lb.GetMetrics()) - - }) + lbMetrics := lb.GetMetrics() + require.Equal(t, nodes, len(lbMetrics), "all nodes should have received traffic") + // Verify results + require.GreaterOrEqual(t, successCount, lowerLimit, + "Success count should be >= lower limit") + require.LessOrEqual(t, successCount, upperLimit, + "Success count should be <= upper limit") + }) + } } } } - } - } + }) + } } diff --git a/go/apps/api/routes/register.go b/go/apps/api/routes/register.go index a3fd9b57f1..8d78c51ce2 100644 --- a/go/apps/api/routes/register.go +++ b/go/apps/api/routes/register.go @@ -52,6 +52,7 @@ func Register(srv *zen.Server, svc *Services) { Permissions: svc.Permissions, RatelimitNamespaceByNameCache: svc.Caches.RatelimitNamespaceByName, RatelimitOverrideMatchesCache: svc.Caches.RatelimitOverridesMatch, + TestMode: srv.Flags().TestMode, }), ) // v2/ratelimit.setOverride diff --git a/go/apps/api/routes/v2_ratelimit_limit/accuracy_test.go b/go/apps/api/routes/v2_ratelimit_limit/accuracy_test.go index adfee2ac41..0d2a2680ee 100644 --- a/go/apps/api/routes/v2_ratelimit_limit/accuracy_test.go +++ b/go/apps/api/routes/v2_ratelimit_limit/accuracy_test.go @@ -54,7 +54,6 @@ func TestRateLimitAccuracy(t *testing.T) { t.Run(fmt.Sprintf("duration_%dms", duration), func(t *testing.T) { for _, loadFactor := range loadFactors { t.Run(fmt.Sprintf("load_%.1fx", loadFactor), func(t *testing.T) { - t.Parallel() h := testutil.NewHarness(t) route := handler.New(handler.Services{ diff --git a/go/apps/api/routes/v2_ratelimit_limit/handler.go b/go/apps/api/routes/v2_ratelimit_limit/handler.go index e4b8bedf75..784d2f9fa5 100644 --- a/go/apps/api/routes/v2_ratelimit_limit/handler.go +++ b/go/apps/api/routes/v2_ratelimit_limit/handler.go @@ -5,6 +5,7 @@ import ( "database/sql" "errors" "net/http" + "strconv" "time" "github.com/unkeyed/unkey/go/apps/api/openapi" @@ -34,6 +35,7 @@ type Services struct { Ratelimit ratelimit.Service RatelimitNamespaceByNameCache cache.Cache[db.FindRatelimitNamespaceByNameParams, db.RatelimitNamespace] RatelimitOverrideMatchesCache cache.Cache[db.FindRatelimitOverrideMatchesParams, []db.RatelimitOverride] + TestMode bool } // New creates a new route handler for ratelimits.limit @@ -173,6 +175,18 @@ func New(svc Services) zen.Route { Duration: time.Duration(duration) * time.Millisecond, Limit: limit, Cost: cost, + Time: time.Time{}, + } + if svc.TestMode { + header := s.Request().Header.Get("X-Test-Time") + if header != "" { + i, parseErr := strconv.ParseInt(header, 10, 64) + if parseErr != nil { + svc.Logger.Warn("invalid test time", "header", header) + } else { + limitReq.Time = time.UnixMilli(i) + } + } } result, err := svc.Ratelimit.Ratelimit(ctx, limitReq) @@ -202,7 +216,7 @@ func New(svc Services) zen.Route { Success: result.Success, Limit: limit, Remaining: result.Remaining, - Reset: result.Reset, + Reset: result.Reset.UnixMilli(), OverrideId: nil, }, } diff --git a/go/apps/api/run.go b/go/apps/api/run.go index ed91280c75..8ee891b4d6 100644 --- a/go/apps/api/run.go +++ b/go/apps/api/run.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "log/slog" - "net" "os" "os/signal" "runtime/debug" @@ -16,17 +15,13 @@ import ( "github.com/unkeyed/unkey/go/internal/services/keys" "github.com/unkeyed/unkey/go/internal/services/permissions" "github.com/unkeyed/unkey/go/internal/services/ratelimit" - "github.com/unkeyed/unkey/go/pkg/aws/ecs" "github.com/unkeyed/unkey/go/pkg/clickhouse" "github.com/unkeyed/unkey/go/pkg/clock" - "github.com/unkeyed/unkey/go/pkg/cluster" + "github.com/unkeyed/unkey/go/pkg/counter" "github.com/unkeyed/unkey/go/pkg/db" - "github.com/unkeyed/unkey/go/pkg/discovery" - "github.com/unkeyed/unkey/go/pkg/membership" "github.com/unkeyed/unkey/go/pkg/otel" "github.com/unkeyed/unkey/go/pkg/otel/logging" "github.com/unkeyed/unkey/go/pkg/prometheus" - "github.com/unkeyed/unkey/go/pkg/rpc" "github.com/unkeyed/unkey/go/pkg/shutdown" "github.com/unkeyed/unkey/go/pkg/version" "github.com/unkeyed/unkey/go/pkg/zen" @@ -49,7 +44,7 @@ func Run(ctx context.Context, cfg Config) error { grafanaErr := otel.InitGrafana(ctx, otel.Config{ Application: "api", Version: version.Version, - InstanceID: cfg.ClusterInstanceID, + InstanceID: cfg.InstanceID, CloudRegion: cfg.Region, TraceSampleRate: cfg.OtelTraceSamplingRate, }, @@ -61,8 +56,8 @@ func Run(ctx context.Context, cfg Config) error { } logger := logging.New() - if cfg.ClusterInstanceID != "" { - logger = logger.With(slog.String("instanceID", cfg.ClusterInstanceID)) + if cfg.InstanceID != "" { + logger = logger.With(slog.String("instanceID", cfg.InstanceID)) } if cfg.Platform != "" { logger = logger.With(slog.String("platform", cfg.Platform)) @@ -74,6 +69,11 @@ func Run(ctx context.Context, cfg Config) error { logger = logger.With(slog.String("version", version.Version)) } + if cfg.TestMode { + logger = logger.With("testmode", true) + logger.Warn("TESTMODE IS ENABLED. This is not secure in production!") + } + // Catch any panics now after we have a logger but before we start the server defer func() { if r := recover(); r != nil { @@ -95,15 +95,9 @@ func Run(ctx context.Context, cfg Config) error { defer db.Close() - d, err := setupDiscovery(cfg, logger, shutdowns) - if err != nil { - return fmt.Errorf("unable to create service discovery: %w", err) - } - if cfg.PrometheusPort > 0 { prom, promErr := prometheus.New(prometheus.Config{ - Discovery: d, - Logger: logger, + Logger: logger, }) if promErr != nil { return fmt.Errorf("unable to start prometheus: %w", promErr) @@ -116,11 +110,6 @@ func Run(ctx context.Context, cfg Config) error { }() } - c, err := setupCluster(cfg, logger, d, shutdowns) - if err != nil { - return fmt.Errorf("unable to create cluster: %w", err) - } - var ch clickhouse.ClickHouse = clickhouse.NewNoop() if cfg.ClickhouseURL != "" { ch, err = clickhouse.New(clickhouse.Config{ @@ -141,8 +130,11 @@ func Run(ctx context.Context, cfg Config) error { } srv, err := zen.New(zen.Config{ - InstanceID: cfg.ClusterInstanceID, + InstanceID: cfg.InstanceID, Logger: logger, + Flags: &zen.Flags{ + TestMode: cfg.TestMode, + }, }) if err != nil { return fmt.Errorf("unable to create server: %w", err) @@ -165,33 +157,23 @@ func Run(ctx context.Context, cfg Config) error { return fmt.Errorf("unable to create key service: %w", err) } + ctr, err := counter.NewRedis(counter.RedisConfig{ + RedisURL: cfg.RedisUrl, + Logger: logger, + }) + if err != nil { + return fmt.Errorf("unable to create counter: %w", err) + } + rlSvc, err := ratelimit.New(ratelimit.Config{ Logger: logger, - Cluster: c, Clock: clk, + Counter: ctr, }) if err != nil { return fmt.Errorf("unable to create ratelimit service: %w", err) } - if cfg.ClusterEnabled { - - rpcSvc, rpcErr := rpc.New(rpc.Config{ - Logger: logger, - RatelimitService: rlSvc, - }) - if rpcErr != nil { - return fmt.Errorf("unable to create rpc service: %w", rpcErr) - } - - go func() { - listenErr := rpcSvc.Listen(ctx, fmt.Sprintf(":%d", cfg.ClusterRpcPort)) - if listenErr != nil { - panic(listenErr) - } - }() - } - p, err := permissions.New(permissions.Config{ DB: db, Logger: logger, @@ -248,110 +230,3 @@ func gracefulShutdown(ctx context.Context, logger logging.Logger, shutdowns *shu } return nil } - -func setupDiscovery(cfg Config, logger logging.Logger, shutdowns *shutdown.Shutdowns) (discovery.Discoverer, error) { - - if cfg.ClusterDiscoveryRedisURL != "" { - advertiseAddr, err := getAdvertiseAddr(cfg) - if err != nil { - return nil, err - } - d, err := discovery.NewRedis(discovery.RedisConfig{ - URL: cfg.ClusterDiscoveryRedisURL, - InstanceID: cfg.ClusterInstanceID, - Addr: advertiseAddr, - Logger: logger, - }) - if err != nil { - return nil, fmt.Errorf("unable to create redis discovery: %w", err) - } - shutdowns.RegisterCtx(d.Shutdown) - return d, nil - } else if cfg.ClusterDiscoveryAwsEcs { - d, err := discovery.NewAwsEcs(discovery.AwsEcsConfig{ - Region: cfg.Region, - Logger: logger, - }) - if err != nil { - return nil, fmt.Errorf("unable to create aws ecs discovery: %w", err) - } - return d, nil - } - return &discovery.Static{ - Addrs: cfg.ClusterDiscoveryStaticAddrs, - }, nil - -} - -func getAdvertiseAddr(cfg Config) (string, error) { - - switch { - case cfg.ClusterAdvertiseAddrStatic != "": - { - - hosts, err := net.LookupHost(cfg.ClusterAdvertiseAddrStatic) - if err != nil { - return "", err - } - if len(hosts) == 0 { - return "", err - } - - return hosts[0], nil - } - case cfg.ClusterAdvertiseAddrAwsEcsMetadata: - { - addr, err := ecs.GetPrivateDnsName() - if err != nil { - return "", fmt.Errorf("unable to get private dns name: %w", err) - } - return addr, nil - - } - - default: - return "", fmt.Errorf("invalid advertise address configuration: %+v", cfg) - } -} -func setupCluster(cfg Config, logger logging.Logger, d discovery.Discoverer, shutdowns *shutdown.Shutdowns) (cluster.Cluster, error) { - if !cfg.ClusterEnabled { - return cluster.NewNoop("", "127.0.0.1"), nil - } - - advertiseAddr, err := getAdvertiseAddr(cfg) - if err != nil { - return nil, err - } - - m, err := membership.New(membership.Config{ - InstanceID: cfg.ClusterInstanceID, - AdvertiseHost: advertiseAddr, - GossipPort: cfg.ClusterGossipPort, - RpcPort: cfg.ClusterRpcPort, - HttpPort: cfg.HttpPort, - Logger: logger, - }) - if err != nil { - return nil, fmt.Errorf("unable to create membership: %w", err) - } - - c, err := cluster.New(cluster.Config{ - Self: cluster.Instance{ - ID: cfg.ClusterInstanceID, - RpcAddr: fmt.Sprintf("%s:%d", advertiseAddr, cfg.ClusterRpcPort), - }, - Logger: logger, - Membership: m, - }) - if err != nil { - return nil, fmt.Errorf("unable to create cluster: %w", err) - } - shutdowns.RegisterCtx(c.Shutdown) - - err = m.Start(d) - if err != nil { - return nil, fmt.Errorf("unable to start membership: %w", err) - } - - return c, nil -} diff --git a/go/apps/api/run_test.go b/go/apps/api/run_test.go deleted file mode 100644 index bf5abb04c0..0000000000 --- a/go/apps/api/run_test.go +++ /dev/null @@ -1,87 +0,0 @@ -package api_test - -import ( - "context" - "fmt" - "net/http" - "testing" - "time" - - "github.com/stretchr/testify/require" - "github.com/unkeyed/unkey/go/apps/api" - "github.com/unkeyed/unkey/go/pkg/port" - "github.com/unkeyed/unkey/go/pkg/testutil" - "github.com/unkeyed/unkey/go/pkg/testutil/containers" - "github.com/unkeyed/unkey/go/pkg/uid" -) - -// TestClusterFormation verifies that a cluster of API nodes can successfully form -// and communicate with each other. -func TestClusterFormation(t *testing.T) { - testutil.SkipUnlessIntegration(t) - - // Create a containers instance for database - containers := containers.New(t) - dbDsn, _ := containers.RunMySQL() - - // Get free ports for the nodes - portAllocator := port.New() - joinAddrs := []string{} - - // Start each node in a separate goroutine - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - clusterSize := 3 - for i := 0; i < clusterSize; i++ { - - instanceID := uid.New(uid.InstancePrefix) - gossipPort := portAllocator.Get() - config := api.Config{ - Platform: "test", - Image: "test", - HttpPort: portAllocator.Get(), - Region: "test-region", - Clock: nil, // Will use real clock - ClusterEnabled: true, - ClusterInstanceID: instanceID, - ClusterAdvertiseAddrStatic: "localhost", - ClusterRpcPort: portAllocator.Get(), - ClusterGossipPort: gossipPort, - ClusterDiscoveryStaticAddrs: joinAddrs, - LogsColor: false, - ClickhouseURL: "", - DatabasePrimary: dbDsn, - DatabaseReadonlyReplica: "", - OtelEnabled: false, - } - - joinAddrs = append(joinAddrs, fmt.Sprintf("localhost:%d", gossipPort)) - - go func() { - require.NoError(t, api.Run(ctx, config)) - }() - - require.Eventually(t, func() bool { - - res, err := http.Get(fmt.Sprintf("http://localhost:%d/v2/liveness", config.HttpPort)) - if err != nil { - return false - } - require.NoError(t, res.Body.Close()) - - return res.StatusCode == http.StatusOK - - }, time.Second*10, time.Millisecond*100) - - } - - t.Log("All nodes started successfully") - - // Now verify cluster formation by checking cluster status endpoints - // Give the cluster a moment to form - time.Sleep(5 * time.Second) - - // Clean up - cancel() // Signal all nodes to shut down - time.Sleep(2 * time.Second) // Give them time to shut down -} diff --git a/go/cmd/api/main.go b/go/cmd/api/main.go index e56e32bd87..6b6f2c474c 100644 --- a/go/cmd/api/main.go +++ b/go/cmd/api/main.go @@ -68,164 +68,19 @@ Examples: Required: false, }, - // Cluster configuration - &cli.BoolFlag{ - Name: "cluster", - Usage: `Enable cluster mode to connect multiple Unkey API nodes together. -When enabled, this node will attempt to form or join a cluster with other Unkey nodes. -Clustering provides high availability, load distribution, and consistent rate limiting across nodes. - -For production deployments with multiple instances, set this to true. -For single-node setups (local development, small deployments), leave this disabled. - -When clustering is enabled, you must also configure: -1. An address advertisement method (static or AWS ECS metadata) -2. A discovery method (static addresses or Redis) -3. Appropriate ports for RPC and gossip protocols - -Examples: - --cluster=true # Enable clustering - --cluster=false # Disable clustering (default)`, - Sources: cli.EnvVars("UNKEY_CLUSTER"), - Required: false, - }, &cli.StringFlag{ - Name: "cluster-instance-id", - Usage: `Unique identifier for this instance within the cluster. -Every instance in a cluster must have a unique identifier. This ID is used in logs, -metrics, and for node-to-node communication within the cluster. - -If not specified, a random UUID with 'node_' prefix will be automatically generated. -For ephemeral nodes (like in auto-scaling groups), automatic generation is appropriate. -For stable deployments, consider setting this to a persistent value tied to the instance. - -Examples: - --cluster-instance-id=instance_east1_001 # For a instance in East region, instance 001 - --cluster-instance-id=instance_replica2 # For a second replica instance - --cluster-instance-id=instance_dev_local # For local development`, - Sources: cli.EnvVars("UNKEY_CLUSTER_NODE_ID"), + Name: "instance-id", + Usage: "Unique identifier for this instance within the cluster.", + Sources: cli.EnvVars("UNKEY_INSTANCE_ID"), Value: uid.New(uid.InstancePrefix), Required: false, }, - &cli.StringFlag{ - Name: "cluster-advertise-addr-static", - Usage: `Static IP address or hostname that other nodes can use to connect to this node. -This is required for clustering when not using AWS ECS discovery. -The address must be reachable by all other nodes in the cluster. - -For on-premises or static cloud deployments, use a fixed IP address or DNS name. -In Kubernetes environments, this could be the pod's DNS name within the cluster. - -Only one advertisement method should be configured - either static or AWS ECS metadata. - -Examples: - --cluster-advertise-addr-static=10.0.1.5 # Direct IP address - --cluster-advertise-addr-static=node1.unkey.internal # DNS name - --cluster-advertise-addr-static=unkey-0.unkey-headless.default.svc.cluster.local # Kubernetes DNS`, - Sources: cli.EnvVars("UNKEY_CLUSTER_ADVERTISE_ADDR_STATIC", "HOSTNAME"), - Required: false, - }, - &cli.BoolFlag{ - Name: "cluster-advertise-addr-aws-ecs-metadata", - Usage: `Enable automatic address discovery using AWS ECS container metadata. -When running on AWS ECS, this flag allows the container to automatically determine -its private DNS name from the ECS metadata service. This simplifies cluster configuration -in AWS ECS deployments with dynamic IP assignments. - -Only one advertisement method should be configured - either static or AWS ECS metadata. -Do not set cluster-advertise-addr-static if this option is enabled. - -This option is specifically designed for AWS ECS and won't work in other environments. - -Examples: - --cluster-advertise-addr-aws-ecs-metadata=true # Enable AWS ECS metadata-based discovery - --cluster-advertise-addr-aws-ecs-metadata=false # Disable (default)`, - Sources: cli.EnvVars("UNKEY_CLUSTER_ADVERTISE_ADDR_AWS_ECS_METADATA"), - Required: false, - }, - &cli.IntFlag{ - Name: "cluster-rpc-port", - Usage: `Port used for internal RPC communication between cluster nodes. -This port is used for direct node-to-node communication within the cluster for -operations like distributed rate limiting and state synchronization. - -The port must be accessible by all other nodes in the cluster and should be -different from the HTTP and gossip ports to avoid conflicts. - -In containerized environments, ensure this port is properly exposed between containers. -For security, this port should typically not be exposed to external networks. - -Examples: - --cluster-rpc-port=7071 # Default RPC port`, - Sources: cli.EnvVars("UNKEY_CLUSTER_RPC_PORT"), - Value: 7071, - Required: false, - }, - &cli.IntFlag{ - Name: "cluster-gossip-port", - Usage: `Port used for cluster membership and failure detection via gossip protocol. -The gossip protocol is used to maintain cluster membership, detect node failures, -and distribute information about the cluster state. - -This port must be accessible by all other nodes in the cluster and should be -different from the HTTP and RPC ports to avoid conflicts. - -In containerized environments, ensure this port is properly exposed between containers. -For security, this port should typically not be exposed to external networks. - -Examples: - --cluster-gossip-port=7072 # Default gossip port`, - Sources: cli.EnvVars("UNKEY_CLUSTER_GOSSIP_PORT"), - Value: 7072, - Required: false, - }, - // Discovery configuration - static - &cli.StringSliceFlag{ - Name: "cluster-discovery-static-addrs", - Usage: `List of seed node addresses for static cluster configuration. -When using static discovery, these addresses serve as initial contact points for -joining the cluster. At least one functioning node address must be provided for -initial cluster formation. - -This flag is required for clustering when not using Redis discovery. -Each address should be a hostname or IP address that's reachable by this node. -It's not necessary to list all nodes - just enough to ensure reliable discovery. -Nodes will automatically discover the full cluster membership after connecting to -any existing cluster member. - -Examples: - --cluster-discovery-static-addrs=10.0.1.5,10.0.1.6 - --cluster-discovery-static-addrs=node1.unkey.internal,node2.unkey.internal - --cluster-discovery-static-addrs=unkey-0.unkey-headless.default.svc.cluster.local`, - Sources: cli.EnvVars("UNKEY_CLUSTER_DISCOVERY_STATIC_ADDRS"), - Required: false, - }, - &cli.BoolFlag{ - Name: "cluster-discovery-aws-ecs", - Usage: `Use the AWS ECS API to find peers within the same cluster.`, - Sources: cli.EnvVars("UNKEY_CLUSTER_DISCOVERY_AWS_ECS"), - Required: false, - }, - // Discovery configuration - Redis + // Redis &cli.StringFlag{ - Name: "cluster-discovery-redis-url", - Usage: `Redis connection string for dynamic cluster discovery. -Redis-based discovery enables nodes to register themselves and discover other nodes -through a shared Redis instance. This is recommended for dynamic environments where -nodes may come and go frequently, such as auto-scaling groups in AWS ECS. - -When specified, nodes will register themselves in Redis and discover other nodes -automatically. This eliminates the need for static address configuration. - -The Redis instance should be accessible by all nodes in the cluster and have -low latency to ensure timely node discovery. - -Examples: - --cluster-discovery-redis-url=redis://localhost:6379/0 - --cluster-discovery-redis-url=redis://user:password@redis.example.com:6379/0 - --cluster-discovery-redis-url=redis://user:password@redis-master.default.svc.cluster.local:6379/0?tls=true`, - Sources: cli.EnvVars("UNKEY_CLUSTER_DISCOVERY_REDIS_URL"), + Name: "redis-url", + Usage: "Redis connection string for cross-cluster semi-durable storage of counters.", + Sources: cli.EnvVars("UNKEY_REDIS_URL"), Required: false, }, // Logs configuration @@ -357,6 +212,17 @@ Default: disabled Value: 0, Required: false, }, + &cli.BoolFlag{ + Name: "test-mode", + Usage: `Enable test mode. This is potentially unsafe. +Testmode enables some flags for testing purposes and may trust client inputs blindly. + +Default: disabled + `, + Sources: cli.EnvVars("UNKEY_TEST_MODE"), + Value: false, + Required: false, + }, }, Action: action, @@ -385,18 +251,11 @@ func action(ctx context.Context, cmd *cli.Command) error { OtelEnabled: cmd.Bool("otel"), OtelTraceSamplingRate: cmd.Float("otel-trace-sampling-rate"), - // Cluster - ClusterEnabled: cmd.Bool("cluster"), - ClusterInstanceID: cmd.String("cluster-instance-id"), - ClusterRpcPort: int(cmd.Int("cluster-rpc-port")), - ClusterGossipPort: int(cmd.Int("cluster-gossip-port")), - ClusterAdvertiseAddrStatic: cmd.String("cluster-advertise-addr-static"), - ClusterAdvertiseAddrAwsEcsMetadata: cmd.Bool("cluster-advertise-addr-aws-ecs-metadata"), - ClusterDiscoveryStaticAddrs: cmd.StringSlice("cluster-discovery-static-addrs"), - ClusterDiscoveryAwsEcs: cmd.Bool("cluster-discovery-aws-ecs"), - ClusterDiscoveryRedisURL: cmd.String("cluster-discovery-redis-url"), - PrometheusPort: int(cmd.Int("prometheus-port")), - Clock: clock.New(), + InstanceID: cmd.String("instance-id"), + RedisUrl: cmd.String("redis-url"), + PrometheusPort: int(cmd.Int("prometheus-port")), + Clock: clock.New(), + TestMode: cmd.Bool("test-mode"), } err := config.Validate() diff --git a/go/gen/proto/ratelimit/v1/ratelimitv1connect/service.connect.go b/go/gen/proto/ratelimit/v1/ratelimitv1connect/service.connect.go deleted file mode 100644 index db77fe5777..0000000000 --- a/go/gen/proto/ratelimit/v1/ratelimitv1connect/service.connect.go +++ /dev/null @@ -1,146 +0,0 @@ -// Code generated by protoc-gen-connect-go. DO NOT EDIT. -// -// Source: proto/ratelimit/v1/service.proto - -package ratelimitv1connect - -import ( - connect "connectrpc.com/connect" - context "context" - errors "errors" - v1 "github.com/unkeyed/unkey/go/gen/proto/ratelimit/v1" - http "net/http" - strings "strings" -) - -// This is a compile-time assertion to ensure that this generated file and the connect package are -// compatible. If you get a compiler error that this constant is not defined, this code was -// generated with a version of connect newer than the one compiled into your binary. You can fix the -// problem by either regenerating this code with an older version of connect or updating the connect -// version compiled into your binary. -const _ = connect.IsAtLeastVersion1_13_0 - -const ( - // RatelimitServiceName is the fully-qualified name of the RatelimitService service. - RatelimitServiceName = "ratelimit.v1.RatelimitService" -) - -// These constants are the fully-qualified names of the RPCs defined in this package. They're -// exposed at runtime as Spec.Procedure and as the final two segments of the HTTP route. -// -// Note that these are different from the fully-qualified method names used by -// google.golang.org/protobuf/reflect/protoreflect. To convert from these constants to -// reflection-formatted method names, remove the leading slash and convert the remaining slash to a -// period. -const ( - // RatelimitServiceReplayProcedure is the fully-qualified name of the RatelimitService's Replay RPC. - RatelimitServiceReplayProcedure = "/ratelimit.v1.RatelimitService/Replay" -) - -// These variables are the protoreflect.Descriptor objects for the RPCs defined in this package. -var ( - ratelimitServiceServiceDescriptor = v1.File_proto_ratelimit_v1_service_proto.Services().ByName("RatelimitService") - ratelimitServiceReplayMethodDescriptor = ratelimitServiceServiceDescriptor.Methods().ByName("Replay") -) - -// RatelimitServiceClient is a client for the ratelimit.v1.RatelimitService service. -type RatelimitServiceClient interface { - // Replay synchronizes rate limit state between nodes using consistent hashing. - // - // Key behaviors: - // - Each identifier maps to exactly one origin server via consistent hashing - // - Edge nodes replay their local rate limit decisions to the origin - // - Origin maintains the source of truth for rate limit state - // - Edge nodes must update their state based on origin responses - // - // Flow: - // 1. Edge node receives rate limit request - // 2. Edge makes local decision (may be defensive) - // 3. Edge replays decision to origin - // 4. Origin processes and returns authoritative state - // 5. Edge updates local state and returns result to client - // - // This approach ensures eventual consistency while allowing for - // fast local decisions at the edge. - Replay(context.Context, *connect.Request[v1.ReplayRequest]) (*connect.Response[v1.ReplayResponse], error) -} - -// NewRatelimitServiceClient constructs a client for the ratelimit.v1.RatelimitService service. By -// default, it uses the Connect protocol with the binary Protobuf Codec, asks for gzipped responses, -// and sends uncompressed requests. To use the gRPC or gRPC-Web protocols, supply the -// connect.WithGRPC() or connect.WithGRPCWeb() options. -// -// The URL supplied here should be the base URL for the Connect or gRPC server (for example, -// http://api.acme.com or https://acme.com/grpc). -func NewRatelimitServiceClient(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) RatelimitServiceClient { - baseURL = strings.TrimRight(baseURL, "/") - return &ratelimitServiceClient{ - replay: connect.NewClient[v1.ReplayRequest, v1.ReplayResponse]( - httpClient, - baseURL+RatelimitServiceReplayProcedure, - connect.WithSchema(ratelimitServiceReplayMethodDescriptor), - connect.WithClientOptions(opts...), - ), - } -} - -// ratelimitServiceClient implements RatelimitServiceClient. -type ratelimitServiceClient struct { - replay *connect.Client[v1.ReplayRequest, v1.ReplayResponse] -} - -// Replay calls ratelimit.v1.RatelimitService.Replay. -func (c *ratelimitServiceClient) Replay(ctx context.Context, req *connect.Request[v1.ReplayRequest]) (*connect.Response[v1.ReplayResponse], error) { - return c.replay.CallUnary(ctx, req) -} - -// RatelimitServiceHandler is an implementation of the ratelimit.v1.RatelimitService service. -type RatelimitServiceHandler interface { - // Replay synchronizes rate limit state between nodes using consistent hashing. - // - // Key behaviors: - // - Each identifier maps to exactly one origin server via consistent hashing - // - Edge nodes replay their local rate limit decisions to the origin - // - Origin maintains the source of truth for rate limit state - // - Edge nodes must update their state based on origin responses - // - // Flow: - // 1. Edge node receives rate limit request - // 2. Edge makes local decision (may be defensive) - // 3. Edge replays decision to origin - // 4. Origin processes and returns authoritative state - // 5. Edge updates local state and returns result to client - // - // This approach ensures eventual consistency while allowing for - // fast local decisions at the edge. - Replay(context.Context, *connect.Request[v1.ReplayRequest]) (*connect.Response[v1.ReplayResponse], error) -} - -// NewRatelimitServiceHandler builds an HTTP handler from the service implementation. It returns the -// path on which to mount the handler and the handler itself. -// -// By default, handlers support the Connect, gRPC, and gRPC-Web protocols with the binary Protobuf -// and JSON codecs. They also support gzip compression. -func NewRatelimitServiceHandler(svc RatelimitServiceHandler, opts ...connect.HandlerOption) (string, http.Handler) { - ratelimitServiceReplayHandler := connect.NewUnaryHandler( - RatelimitServiceReplayProcedure, - svc.Replay, - connect.WithSchema(ratelimitServiceReplayMethodDescriptor), - connect.WithHandlerOptions(opts...), - ) - return "/ratelimit.v1.RatelimitService/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch r.URL.Path { - case RatelimitServiceReplayProcedure: - ratelimitServiceReplayHandler.ServeHTTP(w, r) - default: - http.NotFound(w, r) - } - }) -} - -// UnimplementedRatelimitServiceHandler returns CodeUnimplemented from all methods. -type UnimplementedRatelimitServiceHandler struct{} - -func (UnimplementedRatelimitServiceHandler) Replay(context.Context, *connect.Request[v1.ReplayRequest]) (*connect.Response[v1.ReplayResponse], error) { - return nil, connect.NewError(connect.CodeUnimplemented, errors.New("ratelimit.v1.RatelimitService.Replay is not implemented")) -} diff --git a/go/gen/proto/ratelimit/v1/service.pb.go b/go/gen/proto/ratelimit/v1/service.pb.go deleted file mode 100644 index ab3f28faa5..0000000000 --- a/go/gen/proto/ratelimit/v1/service.pb.go +++ /dev/null @@ -1,529 +0,0 @@ -// Code generated by protoc-gen-go. DO NOT EDIT. -// versions: -// protoc-gen-go v1.36.6 -// protoc (unknown) -// source: proto/ratelimit/v1/service.proto - -package ratelimitv1 - -import ( - protoreflect "google.golang.org/protobuf/reflect/protoreflect" - protoimpl "google.golang.org/protobuf/runtime/protoimpl" - reflect "reflect" - sync "sync" - unsafe "unsafe" -) - -const ( - // Verify that this generated code is sufficiently up-to-date. - _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) - // Verify that runtime/protoimpl is sufficiently up-to-date. - _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) -) - -// RatelimitRequest represents a request to check or consume rate limit tokens. -// This is typically the first point of contact when a client wants to verify -// if they are allowed to perform an action under the rate limit constraints. -type RatelimitRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - // Unique identifier for the rate limit subject. - // This could be: - // - A user ID - // - An API key - // - An IP address - // - Any other unique identifier that needs rate limiting - Identifier string `protobuf:"bytes,1,opt,name=identifier,proto3" json:"identifier,omitempty"` - // Maximum number of tokens allowed within the duration. - // Once this limit is reached, subsequent requests will be denied - // until there is more capacity. - Limit int64 `protobuf:"varint,2,opt,name=limit,proto3" json:"limit,omitempty"` - // Duration of the rate limit window in milliseconds. - // After this duration, a new window begins. - // Common values might be: - // - 1000 (1 second) - // - 60000 (1 minute) - // - 3600000 (1 hour) - Duration int64 `protobuf:"varint,3,opt,name=duration,proto3" json:"duration,omitempty"` - // Number of tokens to consume in this request. - // Higher values can be used for operations that should count more heavily - // against the rate limit (e.g., batch operations). - Cost int64 `protobuf:"varint,4,opt,name=cost,proto3" json:"cost,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *RatelimitRequest) Reset() { - *x = RatelimitRequest{} - mi := &file_proto_ratelimit_v1_service_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *RatelimitRequest) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*RatelimitRequest) ProtoMessage() {} - -func (x *RatelimitRequest) ProtoReflect() protoreflect.Message { - mi := &file_proto_ratelimit_v1_service_proto_msgTypes[0] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use RatelimitRequest.ProtoReflect.Descriptor instead. -func (*RatelimitRequest) Descriptor() ([]byte, []int) { - return file_proto_ratelimit_v1_service_proto_rawDescGZIP(), []int{0} -} - -func (x *RatelimitRequest) GetIdentifier() string { - if x != nil { - return x.Identifier - } - return "" -} - -func (x *RatelimitRequest) GetLimit() int64 { - if x != nil { - return x.Limit - } - return 0 -} - -func (x *RatelimitRequest) GetDuration() int64 { - if x != nil { - return x.Duration - } - return 0 -} - -func (x *RatelimitRequest) GetCost() int64 { - if x != nil { - return x.Cost - } - return 0 -} - -// RatelimitResponse contains the result of a rate limit check. -// This response includes all necessary information for clients to understand -// their current rate limit status and when they can retry if limited. -type RatelimitResponse struct { - state protoimpl.MessageState `protogen:"open.v1"` - // Total limit configured for this window. - // This matches the limit specified in the request and is included - // for convenience in client implementations. - Limit int64 `protobuf:"varint,1,opt,name=limit,proto3" json:"limit,omitempty"` - // Number of tokens remaining in the current window. - // Clients can use this to implement progressive backoff or - // warn users when they're close to their limit. - Remaining int64 `protobuf:"varint,2,opt,name=remaining,proto3" json:"remaining,omitempty"` - // Unix timestamp (in milliseconds) when the current window expires. - // Clients can use this to: - // - Display time until reset to users - // - Implement automatic retry after window reset - // - Schedule future requests optimally - Reset_ int64 `protobuf:"varint,3,opt,name=reset,proto3" json:"reset,omitempty"` - // Whether the rate limit check was successful. - // true = request is allowed - // false = request is denied due to rate limit exceeded - Success bool `protobuf:"varint,4,opt,name=success,proto3" json:"success,omitempty"` - // Current token count in this window. - // This represents how many tokens have been consumed so far, - // useful for monitoring and debugging purposes. - Current int64 `protobuf:"varint,5,opt,name=current,proto3" json:"current,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *RatelimitResponse) Reset() { - *x = RatelimitResponse{} - mi := &file_proto_ratelimit_v1_service_proto_msgTypes[1] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *RatelimitResponse) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*RatelimitResponse) ProtoMessage() {} - -func (x *RatelimitResponse) ProtoReflect() protoreflect.Message { - mi := &file_proto_ratelimit_v1_service_proto_msgTypes[1] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use RatelimitResponse.ProtoReflect.Descriptor instead. -func (*RatelimitResponse) Descriptor() ([]byte, []int) { - return file_proto_ratelimit_v1_service_proto_rawDescGZIP(), []int{1} -} - -func (x *RatelimitResponse) GetLimit() int64 { - if x != nil { - return x.Limit - } - return 0 -} - -func (x *RatelimitResponse) GetRemaining() int64 { - if x != nil { - return x.Remaining - } - return 0 -} - -func (x *RatelimitResponse) GetReset_() int64 { - if x != nil { - return x.Reset_ - } - return 0 -} - -func (x *RatelimitResponse) GetSuccess() bool { - if x != nil { - return x.Success - } - return false -} - -func (x *RatelimitResponse) GetCurrent() int64 { - if x != nil { - return x.Current - } - return 0 -} - -// Window represents a rate limiting time window with its state. -// The system uses a sliding window approach to provide smooth -// rate limiting behavior across window boundaries. -type Window struct { - state protoimpl.MessageState `protogen:"open.v1"` - // Monotonically increasing sequence number for window ordering. - // The sequence is calculated as follows: - // sequence = time.Now().UnixMilli() / duration - Sequence int64 `protobuf:"varint,1,opt,name=sequence,proto3" json:"sequence,omitempty"` - // Duration of the window in milliseconds. - // This matches the duration from the original request and defines - // how long this window remains active. - Duration int64 `protobuf:"varint,2,opt,name=duration,proto3" json:"duration,omitempty"` - // Current token count in this window. - // This is the actual count of tokens consumed during this window's - // lifetime. It must never exceed the configured limit. - Counter int64 `protobuf:"varint,3,opt,name=counter,proto3" json:"counter,omitempty"` - // Start time of the window (Unix timestamp in milliseconds). - // Used to: - // - Calculate window expiration - // - Determine if a window is still active - // - Handle sliding window calculations between current and previous windows - Start int64 `protobuf:"varint,4,opt,name=start,proto3" json:"start,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *Window) Reset() { - *x = Window{} - mi := &file_proto_ratelimit_v1_service_proto_msgTypes[2] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *Window) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*Window) ProtoMessage() {} - -func (x *Window) ProtoReflect() protoreflect.Message { - mi := &file_proto_ratelimit_v1_service_proto_msgTypes[2] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use Window.ProtoReflect.Descriptor instead. -func (*Window) Descriptor() ([]byte, []int) { - return file_proto_ratelimit_v1_service_proto_rawDescGZIP(), []int{2} -} - -func (x *Window) GetSequence() int64 { - if x != nil { - return x.Sequence - } - return 0 -} - -func (x *Window) GetDuration() int64 { - if x != nil { - return x.Duration - } - return 0 -} - -func (x *Window) GetCounter() int64 { - if x != nil { - return x.Counter - } - return 0 -} - -func (x *Window) GetStart() int64 { - if x != nil { - return x.Start - } - return 0 -} - -// ReplayRequest is used to synchronize rate limit state between nodes. -// This is a crucial part of maintaining consistency in a distributed -// rate limiting system. -type ReplayRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - // Original rate limit request that triggered the replay. - // Contains all the parameters needed to evaluate the rate limit - // on the origin server. - Request *RatelimitRequest `protobuf:"bytes,1,opt,name=request,proto3" json:"request,omitempty"` - // Time at which the request was received by the edge node. - // This is used to calculate the sequence number and determine - // the window in which the request falls. - Time int64 `protobuf:"varint,2,opt,name=time,proto3" json:"time,omitempty"` - // Indicates if the edge node denied the request. - // When false: The origin must increment the counter regardless of its own evaluation - // When true: The origin can evaluate the request fresh - // This field is crucial for maintaining consistency when edge nodes - // make defensive denials due to network issues or uncertainty. - Denied bool `protobuf:"varint,3,opt,name=denied,proto3" json:"denied,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *ReplayRequest) Reset() { - *x = ReplayRequest{} - mi := &file_proto_ratelimit_v1_service_proto_msgTypes[3] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *ReplayRequest) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*ReplayRequest) ProtoMessage() {} - -func (x *ReplayRequest) ProtoReflect() protoreflect.Message { - mi := &file_proto_ratelimit_v1_service_proto_msgTypes[3] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use ReplayRequest.ProtoReflect.Descriptor instead. -func (*ReplayRequest) Descriptor() ([]byte, []int) { - return file_proto_ratelimit_v1_service_proto_rawDescGZIP(), []int{3} -} - -func (x *ReplayRequest) GetRequest() *RatelimitRequest { - if x != nil { - return x.Request - } - return nil -} - -func (x *ReplayRequest) GetTime() int64 { - if x != nil { - return x.Time - } - return 0 -} - -func (x *ReplayRequest) GetDenied() bool { - if x != nil { - return x.Denied - } - return false -} - -// ReplayResponse contains the synchronized rate limit state that -// should be used to update both the origin and edge nodes. -type ReplayResponse struct { - state protoimpl.MessageState `protogen:"open.v1"` - // Current active window state. - // This represents the authoritative state of the current window - // as determined by the origin server. - Current *Window `protobuf:"bytes,1,opt,name=current,proto3" json:"current,omitempty"` - // Previous window state for sliding window calculations. - // Used to smooth out rate limiting across window boundaries and - // prevent sharp cliffs in availability during window transitions. - Previous *Window `protobuf:"bytes,2,opt,name=previous,proto3" json:"previous,omitempty"` - // Rate limit response that should be used by the edge node. - // This is the authoritative response that should be returned to - // the client and used to update edge state. - Response *RatelimitResponse `protobuf:"bytes,3,opt,name=response,proto3" json:"response,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *ReplayResponse) Reset() { - *x = ReplayResponse{} - mi := &file_proto_ratelimit_v1_service_proto_msgTypes[4] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *ReplayResponse) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*ReplayResponse) ProtoMessage() {} - -func (x *ReplayResponse) ProtoReflect() protoreflect.Message { - mi := &file_proto_ratelimit_v1_service_proto_msgTypes[4] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use ReplayResponse.ProtoReflect.Descriptor instead. -func (*ReplayResponse) Descriptor() ([]byte, []int) { - return file_proto_ratelimit_v1_service_proto_rawDescGZIP(), []int{4} -} - -func (x *ReplayResponse) GetCurrent() *Window { - if x != nil { - return x.Current - } - return nil -} - -func (x *ReplayResponse) GetPrevious() *Window { - if x != nil { - return x.Previous - } - return nil -} - -func (x *ReplayResponse) GetResponse() *RatelimitResponse { - if x != nil { - return x.Response - } - return nil -} - -var File_proto_ratelimit_v1_service_proto protoreflect.FileDescriptor - -const file_proto_ratelimit_v1_service_proto_rawDesc = "" + - "\n" + - " proto/ratelimit/v1/service.proto\x12\fratelimit.v1\"x\n" + - "\x10RatelimitRequest\x12\x1e\n" + - "\n" + - "identifier\x18\x01 \x01(\tR\n" + - "identifier\x12\x14\n" + - "\x05limit\x18\x02 \x01(\x03R\x05limit\x12\x1a\n" + - "\bduration\x18\x03 \x01(\x03R\bduration\x12\x12\n" + - "\x04cost\x18\x04 \x01(\x03R\x04cost\"\x91\x01\n" + - "\x11RatelimitResponse\x12\x14\n" + - "\x05limit\x18\x01 \x01(\x03R\x05limit\x12\x1c\n" + - "\tremaining\x18\x02 \x01(\x03R\tremaining\x12\x14\n" + - "\x05reset\x18\x03 \x01(\x03R\x05reset\x12\x18\n" + - "\asuccess\x18\x04 \x01(\bR\asuccess\x12\x18\n" + - "\acurrent\x18\x05 \x01(\x03R\acurrent\"p\n" + - "\x06Window\x12\x1a\n" + - "\bsequence\x18\x01 \x01(\x03R\bsequence\x12\x1a\n" + - "\bduration\x18\x02 \x01(\x03R\bduration\x12\x18\n" + - "\acounter\x18\x03 \x01(\x03R\acounter\x12\x14\n" + - "\x05start\x18\x04 \x01(\x03R\x05start\"u\n" + - "\rReplayRequest\x128\n" + - "\arequest\x18\x01 \x01(\v2\x1e.ratelimit.v1.RatelimitRequestR\arequest\x12\x12\n" + - "\x04time\x18\x02 \x01(\x03R\x04time\x12\x16\n" + - "\x06denied\x18\x03 \x01(\bR\x06denied\"\xaf\x01\n" + - "\x0eReplayResponse\x12.\n" + - "\acurrent\x18\x01 \x01(\v2\x14.ratelimit.v1.WindowR\acurrent\x120\n" + - "\bprevious\x18\x02 \x01(\v2\x14.ratelimit.v1.WindowR\bprevious\x12;\n" + - "\bresponse\x18\x03 \x01(\v2\x1f.ratelimit.v1.RatelimitResponseR\bresponse2Y\n" + - "\x10RatelimitService\x12E\n" + - "\x06Replay\x12\x1b.ratelimit.v1.ReplayRequest\x1a\x1c.ratelimit.v1.ReplayResponse\"\x00B@Z>github.com/unkeyed/unkey/go/gen/proto/ratelimit/v1;ratelimitv1b\x06proto3" - -var ( - file_proto_ratelimit_v1_service_proto_rawDescOnce sync.Once - file_proto_ratelimit_v1_service_proto_rawDescData []byte -) - -func file_proto_ratelimit_v1_service_proto_rawDescGZIP() []byte { - file_proto_ratelimit_v1_service_proto_rawDescOnce.Do(func() { - file_proto_ratelimit_v1_service_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_proto_ratelimit_v1_service_proto_rawDesc), len(file_proto_ratelimit_v1_service_proto_rawDesc))) - }) - return file_proto_ratelimit_v1_service_proto_rawDescData -} - -var file_proto_ratelimit_v1_service_proto_msgTypes = make([]protoimpl.MessageInfo, 5) -var file_proto_ratelimit_v1_service_proto_goTypes = []any{ - (*RatelimitRequest)(nil), // 0: ratelimit.v1.RatelimitRequest - (*RatelimitResponse)(nil), // 1: ratelimit.v1.RatelimitResponse - (*Window)(nil), // 2: ratelimit.v1.Window - (*ReplayRequest)(nil), // 3: ratelimit.v1.ReplayRequest - (*ReplayResponse)(nil), // 4: ratelimit.v1.ReplayResponse -} -var file_proto_ratelimit_v1_service_proto_depIdxs = []int32{ - 0, // 0: ratelimit.v1.ReplayRequest.request:type_name -> ratelimit.v1.RatelimitRequest - 2, // 1: ratelimit.v1.ReplayResponse.current:type_name -> ratelimit.v1.Window - 2, // 2: ratelimit.v1.ReplayResponse.previous:type_name -> ratelimit.v1.Window - 1, // 3: ratelimit.v1.ReplayResponse.response:type_name -> ratelimit.v1.RatelimitResponse - 3, // 4: ratelimit.v1.RatelimitService.Replay:input_type -> ratelimit.v1.ReplayRequest - 4, // 5: ratelimit.v1.RatelimitService.Replay:output_type -> ratelimit.v1.ReplayResponse - 5, // [5:6] is the sub-list for method output_type - 4, // [4:5] is the sub-list for method input_type - 4, // [4:4] is the sub-list for extension type_name - 4, // [4:4] is the sub-list for extension extendee - 0, // [0:4] is the sub-list for field type_name -} - -func init() { file_proto_ratelimit_v1_service_proto_init() } -func file_proto_ratelimit_v1_service_proto_init() { - if File_proto_ratelimit_v1_service_proto != nil { - return - } - type x struct{} - out := protoimpl.TypeBuilder{ - File: protoimpl.DescBuilder{ - GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: unsafe.Slice(unsafe.StringData(file_proto_ratelimit_v1_service_proto_rawDesc), len(file_proto_ratelimit_v1_service_proto_rawDesc)), - NumEnums: 0, - NumMessages: 5, - NumExtensions: 0, - NumServices: 1, - }, - GoTypes: file_proto_ratelimit_v1_service_proto_goTypes, - DependencyIndexes: file_proto_ratelimit_v1_service_proto_depIdxs, - MessageInfos: file_proto_ratelimit_v1_service_proto_msgTypes, - }.Build() - File_proto_ratelimit_v1_service_proto = out.File - file_proto_ratelimit_v1_service_proto_goTypes = nil - file_proto_ratelimit_v1_service_proto_depIdxs = nil -} diff --git a/go/internal/services/ratelimit/bucket.go b/go/internal/services/ratelimit/bucket.go index 1854d0343d..af18dc28fd 100644 --- a/go/internal/services/ratelimit/bucket.go +++ b/go/internal/services/ratelimit/bucket.go @@ -6,7 +6,6 @@ import ( "sync" "time" - ratelimitv1 "github.com/unkeyed/unkey/go/gen/proto/ratelimit/v1" "github.com/unkeyed/unkey/go/pkg/otel/metrics" ) @@ -23,7 +22,7 @@ import ( // b := &bucket{ // limit: 100, // duration: time.Minute, -// windows: make(map[int64]*ratelimitv1.Window), +// windows: make(map[int64]window), // } // // b.mu.Lock() @@ -43,7 +42,7 @@ type bucket struct { // Protected by mu // Key: sequence number (calculated from time) // Value: window containing request counts - windows map[int64]*ratelimitv1.Window + windows map[int64]*window // strictUntil is when this bucket must sync with origin // Used after rate limit exceeded to ensure consistency @@ -96,21 +95,19 @@ func (b bucketKey) toString() string { // - bool: true if the bucket already existed, false if it was created func (s *service) getOrCreateBucket(key bucketKey) (*bucket, bool) { - s.bucketsMu.RLock() + s.bucketsMu.Lock() + defer s.bucketsMu.Unlock() b, exists := s.buckets[key.toString()] - s.bucketsMu.RUnlock() if !exists { metrics.Ratelimit.CreatedWindows.Add(context.Background(), 1) b = &bucket{ mu: sync.RWMutex{}, limit: key.limit, duration: key.duration, - windows: make(map[int64]*ratelimitv1.Window), + windows: make(map[int64]*window), strictUntil: time.Time{}, } - s.bucketsMu.Lock() s.buckets[key.toString()] = b - s.bucketsMu.Unlock() } return b, exists } @@ -124,14 +121,12 @@ func (s *service) getOrCreateBucket(key bucketKey) (*bucket, bool) { // - now: Current time to determine the window // // Returns: -// - *ratelimitv1.Window: The current window +// - window: The current window // - bool: True if window existed, false if created // // Thread Safety: // - Caller MUST hold bucket.mu lock -// -// Performance: O(1) time and space complexity -func (b *bucket) getCurrentWindow(now time.Time) (*ratelimitv1.Window, bool) { +func (b *bucket) getCurrentWindow(now time.Time) (*window, bool) { sequence := calculateSequence(now, b.duration) @@ -150,14 +145,14 @@ func (b *bucket) getCurrentWindow(now time.Time) (*ratelimitv1.Window, bool) { // - now: Current time to determine the previous window // // Returns: -// - *ratelimitv1.Window: The previous window +// - window: The previous window // - bool: True if window existed, false if created // // Thread Safety: // - Caller MUST hold bucket.mu lock // // Performance: O(1) time and space complexity -func (b *bucket) getPreviousWindow(now time.Time) (*ratelimitv1.Window, bool) { +func (b *bucket) getPreviousWindow(now time.Time) (*window, bool) { sequence := calculateSequence(now, b.duration) - 1 diff --git a/go/internal/services/ratelimit/interface.go b/go/internal/services/ratelimit/interface.go index dc945eecaf..631a37133d 100644 --- a/go/internal/services/ratelimit/interface.go +++ b/go/internal/services/ratelimit/interface.go @@ -4,8 +4,6 @@ package ratelimit import ( "context" "time" - - ratelimitv1 "github.com/unkeyed/unkey/go/gen/proto/ratelimit/v1" ) // Service defines the core rate limiting functionality. It provides thread-safe @@ -90,6 +88,10 @@ type RatelimitRequest struct { // Must be >= 0. Defaults to 1 if not specified. // The request will be denied if Cost > Limit. Cost int64 + + // Time of the request + // If not specified or zero, the ratelimiter will use its own clock. + Time time.Time } // RatelimitResponse contains the result of a rate limit check and the current state @@ -118,7 +120,7 @@ type RatelimitResponse struct { // - Implement automatic retry after window reset // - Schedule future requests optimally // - Calculate backoff periods - Reset int64 + Reset time.Time // Success indicates whether the rate limit check passed. // true = request is allowed @@ -133,14 +135,6 @@ type RatelimitResponse struct { // - Understanding usage patterns // - Implementing custom backoff strategies Current int64 - - // CurrentWindow contains detailed state about the current time window - // including exact start time, duration, and counter. - CurrentWindow *ratelimitv1.Window - - // PreviousWindow contains state about the previous time window, - // used internally for sliding window calculations. - PreviousWindow *ratelimitv1.Window } // Middleware defines a function type that wraps a Service with additional functionality. diff --git a/go/internal/services/ratelimit/janitor.go b/go/internal/services/ratelimit/janitor.go index db07ea851b..ff00c5cf4d 100644 --- a/go/internal/services/ratelimit/janitor.go +++ b/go/internal/services/ratelimit/janitor.go @@ -48,7 +48,7 @@ func (s *service) expireWindowsAndBuckets() { for bucketID, bucket := range s.buckets { bucket.mu.Lock() for sequence, window := range bucket.windows { - if s.clock.Now().UnixMilli() > (window.GetStart() + (3 * window.GetDuration())) { + if s.clock.Now().After(window.start.Add(3 * window.duration)) { delete(bucket.windows, sequence) metrics.Ratelimit.EvictedWindows.Add(ctx, 1) } else { diff --git a/go/internal/services/ratelimit/peers.go b/go/internal/services/ratelimit/peers.go deleted file mode 100644 index a26478c38f..0000000000 --- a/go/internal/services/ratelimit/peers.go +++ /dev/null @@ -1,198 +0,0 @@ -package ratelimit - -import ( - "context" - "net/http" - "strings" - - "connectrpc.com/connect" - "connectrpc.com/otelconnect" - "github.com/unkeyed/unkey/apps/agent/pkg/tracing" - "github.com/unkeyed/unkey/go/gen/proto/ratelimit/v1/ratelimitv1connect" - "github.com/unkeyed/unkey/go/pkg/cluster" - "github.com/unkeyed/unkey/go/pkg/fault" - "github.com/unkeyed/unkey/go/pkg/otel/metrics" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/metric" -) - -// peer represents a node in the rate limiting cluster that can handle -// rate limit requests. It maintains a connection to the remote node -// and provides methods for state synchronization. -// -// Thread Safety: -// - Immutable after creation -// - Safe for concurrent use -// -// Performance: -// - Connection pooling handled by HTTP client -// - Minimal memory footprint per peer -type peer struct { - // instance contains cluster metadata about the peer - instance cluster.Instance - - // client is the RPC client for communicating with the peer - // Thread-safe for concurrent use - client ratelimitv1connect.RatelimitServiceClient -} - -// syncPeers maintains the service's peer list by listening for cluster membership -// changes and removing peers that have left the cluster. This ensures the rate -// limiter only attempts to communicate with active cluster nodes. -// -// This method should be run in a separate goroutine as it blocks while listening -// for cluster events. -// -// Thread Safety: -// - Safe for concurrent access with other peer operations -// - Uses peerMu to protect peer list modifications -// -// Performance: -// - O(1) per peer removal -// - Minimal memory usage -// - Non-blocking for rate limit operations -// -// Example Usage: -// -// go service.syncPeers() // Start peer synchronization -func (s *service) syncPeers() { - for leave := range s.cluster.SubscribeLeave() { - - s.logger.Info("peer left", "peer", leave.ID) - s.peerMu.Lock() - delete(s.peers, leave.ID) - s.peerMu.Unlock() - } - -} - -// getPeer retrieves or creates a peer connection for the given key. -// The key is used with consistent hashing to determine which node -// should be the origin for a particular rate limit identifier. -// -// Parameters: -// - ctx: Context for cancellation and tracing -// - key: Consistent hash key to identify the peer -// -// Returns: -// - peer: The peer connection, either existing or newly created -// - error: Any errors during peer lookup or connection -// -// Thread Safety: -// - Safe for concurrent use -// - Uses read-write mutex for peer map access -// -// Performance: -// - O(1) for existing peers -// - Network round trip for new peer connections -// - Connection pooling reduces overhead -// -// Errors: -// - Returns error if peer instance cannot be found -// - Returns error if connection cannot be established -// -// Example: -// -// p, err := svc.getPeer(ctx, "user-123") -// if err != nil { -// return fmt.Errorf("failed to get peer: %w", err) -// } -// resp, err := p.client.Ratelimit(ctx, req) -func (s *service) getPeer(ctx context.Context, key string) (peer, error) { - ctx, span := tracing.Start(ctx, "getPeer") - defer span.End() - - var p peer - - defer func() { - metrics.Ratelimit.Origin.Add(ctx, 1, metric.WithAttributeSet(attribute.NewSet( - attribute.String("origin_instance_id", p.instance.ID), - ), - )) - }() - - s.peerMu.RLock() - p, ok := s.peers[key] - s.peerMu.RUnlock() - if ok { - return p, nil - } - - p, err := s.newPeer(context.Background(), key) - if err != nil { - return peer{}, err - } - s.peerMu.Lock() - s.peers[key] = p - s.peerMu.Unlock() - return p, nil - -} - -// newPeer creates a new peer connection to a cluster node. -// It establishes the RPC client connection and configures tracing. -// -// Parameters: -// - ctx: Context for cancellation and tracing -// - key: Consistent hash key to identify the peer -// -// Returns: -// - peer: Newly created peer connection -// - error: Any errors during peer creation -// -// Thread Safety: -// - Caller must hold peerMu lock -// - Resulting peer is safe for concurrent use -// -// Performance: -// - Network round trip for initial connection -// - Creates new HTTP client and interceptors -// -// Errors: -// - Returns error if instance lookup fails -// - Returns error if interceptor creation fails -// - Returns error if RPC client creation fails -// -// Example: -// -// s.peerMu.Lock() -// p, err := s.newPeer(ctx, "user-123") -// if err != nil { -// s.peerMu.Unlock() -// return fmt.Errorf("failed to create peer: %w", err) -// } -// s.peers[key] = p -// s.peerMu.Unlock() -func (s *service) newPeer(ctx context.Context, key string) (peer, error) { - ctx, span := tracing.Start(ctx, "ratelimit.newPeer") - defer span.End() - - s.peerMu.Lock() - defer s.peerMu.Unlock() - - instance, err := s.cluster.FindInstance(ctx, key) - if err != nil { - return peer{}, fault.Wrap(err, fault.WithDesc("failed to find instance", "The ratelimit origin could not be found.")) - } - - s.logger.Info("peer added", - "peer", instance.ID, - "address", instance.RpcAddr, - ) - rpcAddr := instance.RpcAddr - if !strings.Contains(rpcAddr, "://") { - rpcAddr = "http://" + rpcAddr - } - - interceptor, err := otelconnect.NewInterceptor( - otelconnect.WithTracerProvider(tracing.GetGlobalTraceProvider()), - otelconnect.WithoutServerPeerAttributes(), - ) - if err != nil { - s.logger.Error("failed to create interceptor", "error", err.Error()) - return peer{}, err - } - - c := ratelimitv1connect.NewRatelimitServiceClient(http.DefaultClient, rpcAddr, connect.WithInterceptors(interceptor)) - return peer{instance: instance, client: c}, nil -} diff --git a/go/internal/services/ratelimit/replay.go b/go/internal/services/ratelimit/replay.go index 63e37be9a6..3c35daf316 100644 --- a/go/internal/services/ratelimit/replay.go +++ b/go/internal/services/ratelimit/replay.go @@ -4,11 +4,9 @@ import ( "context" "time" - "connectrpc.com/connect" - ratelimitv1 "github.com/unkeyed/unkey/go/gen/proto/ratelimit/v1" + "github.com/unkeyed/unkey/go/pkg/assert" "github.com/unkeyed/unkey/go/pkg/otel/metrics" "github.com/unkeyed/unkey/go/pkg/otel/tracing" - "go.opentelemetry.io/otel/attribute" ) // replayRequests processes buffered rate limit events by synchronizing them with @@ -38,7 +36,7 @@ import ( // } func (s *service) replayRequests() { for req := range s.replayBuffer.Consume() { - _, err := s.syncWithOrigin(context.Background(), req) + err := s.syncWithOrigin(context.Background(), req) if err != nil { s.logger.Error("failed to replay request", "error", err.Error()) } @@ -47,46 +45,7 @@ func (s *service) replayRequests() { } -// syncWithOrigin synchronizes a rate limit decision with the origin node for -// a given identifier. This ensures consistent rate limiting across the cluster -// by having authoritative nodes for each rate limit bucket. -// -// The method may return (nil, nil) if the current node is the origin, -// indicating no synchronization was needed. -// -// Parameters: -// - ctx: Context for cancellation and tracing -// - req: The rate limit event to synchronize -// -// Returns: -// - *ReplayResponse: The origin's rate limit state, or nil if local node is origin -// - error: Any errors during synchronization -// -// Thread Safety: -// - Safe for concurrent use -// - Uses internal synchronization for state updates -// -// Performance: -// - Network round trip to origin node -// - Circuit breaker prevents cascading failures -// - 5s overall timeout, 2s per RPC attempt -// -// Errors: -// - Returns error if peer lookup fails -// - Returns error if RPC fails -// - Returns error if circuit breaker is open -// -// Example: -// -// resp, err := svc.syncWithOrigin(ctx, &ratelimitv1.ReplayRequest{ -// Request: &ratelimitv1.RatelimitRequest{ -// Identifier: "user-123", -// Limit: 100, -// Duration: 60000, // 1 minute -// }, -// Time: time.Now().UnixMilli(), -// }) -func (s *service) syncWithOrigin(ctx context.Context, req *ratelimitv1.ReplayRequest) (*ratelimitv1.ReplayResponse, error) { +func (s *service) syncWithOrigin(ctx context.Context, req RatelimitRequest) error { defer func(start time.Time) { metrics.Ratelimit.OriginSyncLatency.Record(ctx, time.Since(start).Milliseconds()) }(time.Now()) @@ -97,123 +56,41 @@ func (s *service) syncWithOrigin(ctx context.Context, req *ratelimitv1.ReplayReq ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() - now := s.clock.Now() - - key := bucketKey{ - req.GetRequest().GetIdentifier(), - req.GetRequest().GetLimit(), - time.Duration(req.GetRequest().GetDuration()) * time.Millisecond, - }.toString() - p, err := s.getPeer(ctx, key) + err := assert.False(req.Time.IsZero(), "request time must not be zero when replaying") if err != nil { - tracing.RecordError(span, err) - - return nil, err + return err } - span.SetAttributes(attribute.String("originInstanceID", p.instance.ID)) - - if p.instance.ID == s.cluster.Self().ID { - // we're the origin, nothing to replay... - // nolint:nilnil - return nil, nil + key := bucketKey{ + identifier: req.Identifier, + limit: req.Limit, + duration: req.Duration, } - res, err := s.replayCircuitBreaker.Do(ctx, func(innerCtx context.Context) (*connect.Response[ratelimitv1.ReplayResponse], error) { + bucket, _ := s.getOrCreateBucket(key) + bucket.mu.Lock() + defer bucket.mu.Unlock() + currentWindow, _ := bucket.getCurrentWindow(req.Time) + + newCounter, err := s.replayCircuitBreaker.Do(ctx, func(innerCtx context.Context) (int64, error) { innerCtx, cancel = context.WithTimeout(innerCtx, 2*time.Second) defer cancel() - return p.client.Replay(innerCtx, connect.NewRequest(req)) + + return s.counter.Increment( + innerCtx, + counterKey(key, currentWindow.sequence), + req.Cost, + currentWindow.duration*3, + ) + }) if err != nil { tracing.RecordError(span, err) - return nil, err + return err } - - s.SetWindows(ctx, - setWindowRequest{ - Identifier: req.GetRequest().GetIdentifier(), - Limit: req.GetRequest().GetLimit(), - Counter: res.Msg.GetCurrent().GetCounter(), - Sequence: res.Msg.GetCurrent().GetSequence(), - Duration: time.Duration(req.GetRequest().GetDuration()) * time.Millisecond, - Time: now, - }, - setWindowRequest{ - Identifier: req.GetRequest().GetIdentifier(), - Limit: req.GetRequest().GetLimit(), - Counter: res.Msg.GetPrevious().GetCounter(), - Sequence: res.Msg.GetPrevious().GetSequence(), - Duration: time.Duration(req.GetRequest().GetDuration()) * time.Millisecond, - Time: now, - }, - ) - - return res.Msg, nil -} - -// Replay handles incoming RPC requests to synchronize rate limit state. -// It is called by other nodes in the cluster when they need to verify -// their local rate limit decisions with this node (when this node is -// the origin for a particular identifier). -// -// Parameters: -// - ctx: Context for cancellation and tracing -// - req: The rate limit event to verify -// -// Returns: -// - *Response[ReplayResponse]: Current rate limit state -// - error: Any errors during processing -// -// Thread Safety: -// - Safe for concurrent use -// - Uses bucket-level locking -// -// Performance: -// - O(1) time complexity -// - No network calls -// - Uses local state only -// -// RPC Interface: -// - Part of the RatelimitService gRPC interface -// - Called automatically by peer nodes -// - Handles cluster-wide state synchronization -// -// Example: -// -// // Called via RPC from other nodes -// resp, err := svc.Replay(ctx, connect.NewRequest(&ratelimitv1.ReplayRequest{ -// Request: &ratelimitv1.RatelimitRequest{ -// Identifier: "user-123", -// Limit: 100, -// Duration: 60000, -// }, -// Time: time.Now().UnixMilli(), -// })) -func (s *service) Replay(ctx context.Context, req *connect.Request[ratelimitv1.ReplayRequest]) (*connect.Response[ratelimitv1.ReplayResponse], error) { - ctx, span := tracing.Start(ctx, "Replay") - defer span.End() - t := time.UnixMilli(req.Msg.GetTime()) - - res, _, err := s.localRatelimit(ctx, t, RatelimitRequest{ - Identifier: req.Msg.GetRequest().GetIdentifier(), - Limit: req.Msg.GetRequest().GetLimit(), - Duration: time.Duration(req.Msg.GetRequest().GetDuration()) * time.Millisecond, - Cost: req.Msg.GetRequest().GetCost(), - }) - - if err != nil { - return nil, err + if newCounter > currentWindow.counter { + currentWindow.counter = newCounter } - return connect.NewResponse(&ratelimitv1.ReplayResponse{ - Current: res.CurrentWindow, - Previous: res.PreviousWindow, - Response: &ratelimitv1.RatelimitResponse{ - Limit: res.Limit, - Remaining: res.Remaining, - Reset_: res.Reset, - Success: res.Success, - Current: res.Current, - }, - }), nil + return nil } diff --git a/go/internal/services/ratelimit/service.go b/go/internal/services/ratelimit/service.go index 7dba2e0443..65a37cdd94 100644 --- a/go/internal/services/ratelimit/service.go +++ b/go/internal/services/ratelimit/service.go @@ -3,18 +3,13 @@ package ratelimit import ( "context" "sync" - "time" - "connectrpc.com/connect" - ratelimitv1 "github.com/unkeyed/unkey/go/gen/proto/ratelimit/v1" - "github.com/unkeyed/unkey/go/gen/proto/ratelimit/v1/ratelimitv1connect" "github.com/unkeyed/unkey/go/pkg/assert" "github.com/unkeyed/unkey/go/pkg/buffer" "github.com/unkeyed/unkey/go/pkg/circuitbreaker" "github.com/unkeyed/unkey/go/pkg/clock" - "github.com/unkeyed/unkey/go/pkg/cluster" + "github.com/unkeyed/unkey/go/pkg/counter" "github.com/unkeyed/unkey/go/pkg/otel/logging" - "github.com/unkeyed/unkey/go/pkg/otel/metrics" "github.com/unkeyed/unkey/go/pkg/otel/tracing" "go.opentelemetry.io/otel/attribute" ) @@ -68,9 +63,6 @@ type service struct { // logger handles structured logging output logger logging.Logger - // cluster manages node discovery and state distribution - cluster cluster.Cluster - // shutdownCh signals service shutdown shutdownCh chan struct{} @@ -81,28 +73,24 @@ type service struct { // Protected by bucketsMu buckets map[string]*bucket - // peerMu protects access to peer-related fields - peerMu sync.RWMutex - - // peers maps node IDs to peer connections - // Protected by peerMu - peers map[string]peer + // counter is the distributed counter implementation + counter counter.Counter // replayBuffer holds rate limit events for async propagation // Thread-safe internally - replayBuffer *buffer.Buffer[*ratelimitv1.ReplayRequest] + replayBuffer *buffer.Buffer[RatelimitRequest] // replayCircuitBreaker prevents cascading failures during peer communication // Thread-safe internally - replayCircuitBreaker circuitbreaker.CircuitBreaker[*connect.Response[ratelimitv1.ReplayResponse]] + replayCircuitBreaker circuitbreaker.CircuitBreaker[int64] } -var _ ratelimitv1connect.RatelimitServiceHandler = (*service)(nil) - type Config struct { - Logger logging.Logger - Cluster cluster.Cluster - Clock clock.Clock + Logger logging.Logger + + Clock clock.Clock + // If provided, use this counter implementation instead of creating a Redis counter + Counter counter.Counter } // New creates a new rate limiting service with the given configuration. @@ -143,17 +131,14 @@ func New(config Config) (*service, error) { s := &service{ clock: config.Clock, logger: config.Logger, - cluster: config.Cluster, shutdownCh: make(chan struct{}), bucketsMu: sync.RWMutex{}, buckets: make(map[string]*bucket), - peerMu: sync.RWMutex{}, - peers: make(map[string]peer), - replayBuffer: buffer.New[*ratelimitv1.ReplayRequest](10_000, true), - replayCircuitBreaker: circuitbreaker.New[*connect.Response[ratelimitv1.ReplayResponse]]("replayRatelimitRequest"), + counter: config.Counter, + replayBuffer: buffer.New[RatelimitRequest](10_000, true), + replayCircuitBreaker: circuitbreaker.New[int64]("replayRatelimitRequest"), } - go s.syncPeers() s.expireWindowsAndBuckets() // start multiple goroutines to do replays @@ -164,6 +149,12 @@ func New(config Config) (*service, error) { return s, nil } +// Close releases all resources held by the rate limiter. +// It should be called when the service is no longer needed. +func (s *service) Close() error { + return s.counter.Close() +} + // Ratelimit checks if a request should be allowed under current rate limit constraints. // It implements a sliding window algorithm that considers both the current and previous // time windows to provide accurate rate limiting across a cluster of nodes. @@ -186,10 +177,6 @@ func New(config Config) (*service, error) { // - Returns validation errors for invalid parameters // - May return errors from cluster communication // -// Performance: -// - O(1) time complexity for local decisions -// - Network round trip only when syncing with origin -// // Thread Safety: // - Safe for concurrent use // - State updates are atomic @@ -208,119 +195,77 @@ func New(config Config) (*service, error) { // return fmt.Errorf("rate limit exceeded, retry after %v", // time.UnixMilli(resp.Reset)) // } -func (r *service) Ratelimit(ctx context.Context, req RatelimitRequest) (RatelimitResponse, error) { +func (s *service) Ratelimit(ctx context.Context, req RatelimitRequest) (RatelimitResponse, error) { + _, span := tracing.Start(ctx, "Ratelimit") defer span.End() + if req.Time.IsZero() { + req.Time = s.clock.Now() + } + err := assert.All( assert.NotEmpty(req.Identifier, "ratelimit identifier must not be empty"), assert.Greater(req.Limit, 0, "ratelimit limit must be greater than zero"), assert.GreaterOrEqual(req.Cost, 0, "ratelimit cost must not be negative"), assert.GreaterOrEqual(req.Duration.Milliseconds(), 1000, "ratelimit duration must be at least 1s"), + assert.False(req.Time.IsZero(), "request time must not be zero"), ) if err != nil { return RatelimitResponse{}, err } - now := r.clock.Now() - - localRes, goToOrigin, err := r.localRatelimit(ctx, now, req) - if err != nil { - return RatelimitResponse{}, err - } - - replayRequest := &ratelimitv1.ReplayRequest{ - Request: &ratelimitv1.RatelimitRequest{ - Identifier: req.Identifier, - Limit: req.Limit, - Duration: req.Duration.Milliseconds(), - Cost: req.Cost, - }, - Time: now.UnixMilli(), - Denied: !localRes.Success, - } - - if goToOrigin { - metrics.Ratelimit.OriginDecisions.Add(ctx, 1) - originRes, err := r.syncWithOrigin(ctx, replayRequest) - if err != nil { - r.logger.Error("unable to ask the origin", - "error", err.Error(), - "identifier", req.Identifier, - ) - } else if originRes != nil { - return RatelimitResponse{ - Limit: originRes.GetResponse().GetLimit(), - Remaining: originRes.GetResponse().GetRemaining(), - Reset: originRes.GetResponse().GetReset_(), - Success: originRes.GetResponse().GetSuccess(), - Current: originRes.GetResponse().GetCurrent(), - CurrentWindow: originRes.GetCurrent(), - PreviousWindow: originRes.GetPrevious(), - }, nil - } - } else { - r.replayBuffer.Buffer(replayRequest) - metrics.Ratelimit.LocalDecisions.Add(ctx, 1) - - } - - return localRes, nil -} - -// localRatelimit performs a rate limit check using only local state. -// It implements the core sliding window algorithm and determines if -// synchronization with the origin node is needed. -// -// Parameters: -// - ctx: Context for cancellation and tracing -// - now: Current time (from service clock) -// - req: The rate limit request -// -// Returns: -// - RatelimitResponse: The local rate limit decision -// - bool: True if sync with origin is needed -// - error: Any errors during processing -// -// Algorithm: -// 1. Gets or creates local rate limit bucket -// 2. Calculates effective count using sliding window -// 3. Updates local state if request is allowed -// 4. Determines if origin sync is needed -// -// Thread Safety: -// - Protected by bucket mutex -// - Safe for concurrent calls -// -// Performance: O(1) time complexity -func (r *service) localRatelimit(ctx context.Context, now time.Time, req RatelimitRequest) (RatelimitResponse, bool, error) { - _, span := tracing.Start(ctx, "localRatelimit") - defer span.End() - key := bucketKey{req.Identifier, req.Limit, req.Duration} span.SetAttributes(attribute.String("key", key.toString())) - b, _ := r.getOrCreateBucket(key) + b, _ := s.getOrCreateBucket(key) b.mu.Lock() defer b.mu.Unlock() - goToOrigin := now.UnixMilli() < b.strictUntil.UnixMilli() + goToOrigin := req.Time.UnixMilli() < b.strictUntil.UnixMilli() // Get current and previous windows - currentWindow, currentWindowExisted := b.getCurrentWindow(now) - previousWindow, _ := b.getPreviousWindow(now) + currentWindow, currentWindowExisted := b.getCurrentWindow(req.Time) + previousWindow, previousWindowExisted := b.getPreviousWindow(req.Time) + + refreshKeys := []string{} + currentKey := "" + previousKey := "" + + if goToOrigin || !currentWindowExisted { + currentKey = counterKey(key, currentWindow.sequence) + refreshKeys = append(refreshKeys, currentKey) + + } + if goToOrigin || !previousWindowExisted { + previousKey = counterKey(key, previousWindow.sequence) + refreshKeys = append(refreshKeys, previousKey) + } + + if len(refreshKeys) > 0 { + res, err := s.counter.MultiGet(ctx, refreshKeys) + if err != nil { + s.logger.Error("unable to get counter values", + "keys", refreshKeys, + "error", err.Error(), + ) + } + if counter := res[currentKey]; counter > currentWindow.counter { + currentWindow.counter = counter + } + if counter := res[previousKey]; counter > previousWindow.counter { + previousWindow.counter = counter + } - if !currentWindowExisted { - goToOrigin = true } // Calculate time elapsed in current window (as a fraction) - windowElapsed := float64(now.UnixMilli()-currentWindow.GetStart()) / float64(req.Duration.Milliseconds()) + windowElapsed := float64(req.Time.Sub(currentWindow.start).Milliseconds()) / float64(req.Duration.Milliseconds()) // Pure sliding window calculation: // - We count 100% of current window // - We count a decreasing portion of previous window based on how far we are into current window - effectiveCount := currentWindow.GetCounter() + int64(float64(previousWindow.GetCounter())*(1.0-windowElapsed)) + effectiveCount := currentWindow.counter + int64(float64(previousWindow.counter)*(1.0-windowElapsed)) effectiveCount += req.Cost @@ -331,22 +276,20 @@ func (r *service) localRatelimit(ctx context.Context, now time.Time, req Ratelim remaining = 0 } - b.strictUntil = now.Add(req.Duration) + b.strictUntil = req.Time.Add(req.Duration) span.SetAttributes(attribute.Bool("passed", false)) return RatelimitResponse{ - Success: false, - Remaining: remaining, - Reset: currentWindow.GetStart() + currentWindow.GetDuration(), - Limit: req.Limit, - Current: effectiveCount, - CurrentWindow: currentWindow, - PreviousWindow: previousWindow, - }, goToOrigin, nil + Success: false, + Remaining: remaining, + Reset: currentWindow.start.Add(currentWindow.duration), + Limit: req.Limit, + Current: effectiveCount, + }, nil } // Increment current window counter - currentWindow.Counter += req.Cost + currentWindow.counter += req.Cost remaining := req.Limit - effectiveCount if remaining < 0 { @@ -354,13 +297,13 @@ func (r *service) localRatelimit(ctx context.Context, now time.Time, req Ratelim } span.SetAttributes(attribute.Bool("passed", true)) + s.replayBuffer.Buffer(req) + return RatelimitResponse{ - Success: true, - Remaining: remaining, - Reset: currentWindow.GetStart() + currentWindow.GetDuration(), - Limit: req.Limit, - Current: currentWindow.GetCounter(), - CurrentWindow: currentWindow, - PreviousWindow: previousWindow, - }, goToOrigin, nil + Success: true, + Remaining: remaining, + Reset: currentWindow.start.Add(currentWindow.duration), + Limit: req.Limit, + Current: currentWindow.counter, + }, nil } diff --git a/go/internal/services/ratelimit/util.go b/go/internal/services/ratelimit/util.go new file mode 100644 index 0000000000..4f8758d822 --- /dev/null +++ b/go/internal/services/ratelimit/util.go @@ -0,0 +1,7 @@ +package ratelimit + +import "fmt" + +func counterKey(b bucketKey, seq int64) string { + return fmt.Sprintf("%s:%d", b.toString(), seq) +} diff --git a/go/internal/services/ratelimit/window.go b/go/internal/services/ratelimit/window.go index 138c71ded1..9244417aa2 100644 --- a/go/internal/services/ratelimit/window.go +++ b/go/internal/services/ratelimit/window.go @@ -4,10 +4,28 @@ import ( "context" "time" - ratelimitv1 "github.com/unkeyed/unkey/go/gen/proto/ratelimit/v1" "github.com/unkeyed/unkey/go/pkg/otel/metrics" ) +type window struct { + // sequence = time.Now().UnixMilli() / duration + sequence int64 + // Duration of the window in milliseconds. + // This matches the duration from the original request and defines + // how long this window remains active. + duration time.Duration + // Current token count in this window. + // This is the actual count of tokens consumed during this window's + // lifetime. It must never exceed the configured limit. + counter int64 + // Start time of the window (Unix timestamp in milliseconds). + // Used to: + // - Calculate window expiration + // - Determine if a window is still active + // - Handle sliding window calculations between current and previous windows + start time.Time +} + // newWindow creates a new rate limit window starting at the given time. // Windows are aligned to duration boundaries (e.g., on the minute for // minute-based limits) to ensure consistent behavior across nodes. @@ -37,92 +55,12 @@ import ( // time.Now(), // time.Minute, // ) -func newWindow(sequence int64, t time.Time, duration time.Duration) *ratelimitv1.Window { +func newWindow(sequence int64, t time.Time, duration time.Duration) *window { metrics.Ratelimit.CreatedWindows.Add(context.Background(), 1) - return &ratelimitv1.Window{ - Sequence: sequence, - Start: t.Truncate(duration).UnixMilli(), - Duration: duration.Milliseconds(), - Counter: 0, - } -} - -// setWindowRequest contains parameters for updating a rate limit window. -// Used to synchronize window state across cluster nodes. -// -// Thread Safety: -// - Immutable after creation -// - Safe for concurrent use -type setWindowRequest struct { - // Identifier uniquely identifies the rate limit subject - Identifier string - - // Limit is the maximum allowed requests per duration - Limit int64 - - // Duration is the time window length - Duration time.Duration - - // Sequence uniquely identifies this window - Sequence int64 - - // Time is any timestamp within the target window - // Will be aligned to window boundaries - Time time.Time - - // Counter is the new request count for this window - Counter int64 -} - -// SetWindows updates the state of one or more rate limit windows. -// Used to synchronize window state across cluster nodes and handle -// replay requests from other nodes. -// -// Parameters: -// - ctx: Context for cancellation and tracing -// - requests: Window states to update -// -// Thread Safety: -// - Safe for concurrent use -// - Updates are atomic per window -// -// Performance: -// - O(n) where n is number of requests -// - Acquires/releases bucket mutex for each window -// -// Behavior: -// - Only increases window counters, never decreases -// - Creates missing windows/buckets as needed -// - Maintains monotonic counter invariant -// -// Example: -// -// svc.SetWindows(ctx, setWindowRequest{ -// Identifier: "user-123", -// Limit: 100, -// Duration: time.Minute, -// Sequence: 42, -// Time: time.Now(), -// Counter: 5, -// }) -func (r *service) SetWindows(ctx context.Context, requests ...setWindowRequest) { - for _, req := range requests { - key := bucketKey{req.Identifier, req.Limit, req.Duration} - bucket, _ := r.getOrCreateBucket(key) - bucket.mu.Lock() - window, ok := bucket.windows[req.Sequence] - if !ok { - window = newWindow(req.Sequence, req.Time, req.Duration) - bucket.windows[req.Sequence] = window - } - - // Only increment the current value if the new value is greater than the current value - // Due to varying network latency, we may receive out of order responses and could decrement the - // current value, which would result in inaccurate rate limiting - if req.Counter > window.GetCounter() { - window.Counter = req.Counter - } - bucket.mu.Unlock() - + return &window{ + sequence: sequence, + start: t.Truncate(duration), + duration: duration, + counter: 0, } } diff --git a/go/pkg/counter/doc.go b/go/pkg/counter/doc.go new file mode 100644 index 0000000000..7a8f59d70d --- /dev/null +++ b/go/pkg/counter/doc.go @@ -0,0 +1,43 @@ +/* +Package counter provides abstractions and implementations for distributed counters. + +This package contains interfaces and concrete implementations for tracking and +incrementing counter values in distributed environments. It can be used for +various purposes such as rate limiting, usage tracking, and statistics. + +Architecture: + - Uses a simple interface that can be implemented with various backends + - Provides a Redis implementation for distributed scenarios + - Supports middleware pattern for extending functionality + +Thread Safety: + - All implementations are safe for concurrent use + - Operations are atomic across distributed systems (depending on implementation) + +Example Usage: + + import "github.com/unkeyed/unkey/go/pkg/counter" + + // Create a Redis-backed counter + redisCounter, err := counter.NewRedis(counter.RedisConfig{ + RedisURL: "redis://localhost:6379", + Logger: logger, + }) + if err != nil { + return err + } + defer redisCounter.Close() + + // Increment a counter + newValue, err := redisCounter.Increment(ctx, "my-counter", 1) + if err != nil { + return err + } + + // Get a counter value + value, err := redisCounter.Get(ctx, "my-counter") + if err != nil { + return err + } +*/ +package counter \ No newline at end of file diff --git a/go/pkg/counter/interface.go b/go/pkg/counter/interface.go new file mode 100644 index 0000000000..307e4c06ec --- /dev/null +++ b/go/pkg/counter/interface.go @@ -0,0 +1,80 @@ +// Package counter provides abstractions for distributed counters. +// It defines interfaces and implementations for tracking and incrementing +// counters in a distributed environment. +package counter + +import ( + "context" + "time" +) + +// Counter defines the interface for a distributed counter. +// It provides operations to increment and retrieve counter values +// in a thread-safe and distributed manner. +// +// Implementations of this interface are expected to handle: +// - Thread safety for concurrent operations +// - Persistence of counter values +// - Distribution across nodes if applicable +// +// Concurrency: All methods are safe for concurrent use. +type Counter interface { + // Increment increases the counter by the given value and returns the new count. + // + // Parameters: + // - ctx: Context for cancellation and tracing + // - key: Unique identifier for the counter + // - value: Amount to increment the counter by + // - ttl: Optional time-to-live duration for the counter. If provided and + // the counter is newly created, implementations should set this TTL. + // If nil, no TTL is set. + // + // Returns: + // - int64: The new counter value after incrementing + // - error: Any errors that occurred during the operation + Increment(ctx context.Context, key string, value int64, ttl ...time.Duration) (int64, error) + + // Get retrieves the current value of a counter. + // + // Parameters: + // - ctx: Context for cancellation and tracing + // - key: Unique identifier for the counter + // + // Returns: + // - int64: The current counter value + // - error: Any errors that occurred during the operation + // If the counter doesn't exist, implementations should + // return 0 and nil error, not an error. + Get(ctx context.Context, key string) (int64, error) + + // MultiGet retrieves the values of multiple counters in a single operation. + // + // Parameters: + // - ctx: Context for cancellation and tracing + // - keys: Slice of unique identifiers for the counters + // + // Returns: + // - map[string]int64: Map of counter keys to their current values + // - error: Any errors that occurred during the operation + // If a counter doesn't exist, its value in the map will be 0. + MultiGet(ctx context.Context, keys []string) (map[string]int64, error) + + // Close releases any resources held by the counter implementation. + // After calling Close(), the counter instance should not be used again. + // + // Returns: + // - error: Any errors that occurred during shutdown + Close() error +} + +// Middleware defines a function type that wraps a Counter with additional functionality. +// It can be used to add logging, metrics, validation, or other cross-cutting concerns. +// +// Example Usage: +// +// func LoggingMiddleware(logger Logger) Middleware { +// return func(next Counter) Counter { +// return &loggingCounter{next: next, logger: logger} +// } +// } +type Middleware func(Counter) Counter diff --git a/go/pkg/counter/redis.go b/go/pkg/counter/redis.go new file mode 100644 index 0000000000..55796c3cdd --- /dev/null +++ b/go/pkg/counter/redis.go @@ -0,0 +1,193 @@ +package counter + +import ( + "context" + "fmt" + "strconv" + "time" + + "github.com/redis/go-redis/v9" + "github.com/unkeyed/unkey/go/pkg/assert" + "github.com/unkeyed/unkey/go/pkg/otel/logging" + "github.com/unkeyed/unkey/go/pkg/otel/tracing" +) + +// redisCounter implements the Counter interface using Redis. +// It provides distributed counter functionality backed by Redis. +type redisCounter struct { + // redis is the Redis client + redis *redis.Client + + // logger for logging + logger logging.Logger +} + +var _ Counter = (*redisCounter)(nil) + +// RedisConfig holds configuration options for the Redis counter. +type RedisConfig struct { + // RedisURL is the connection URL for Redis. + // Format: redis://[[username][:password]@][host][:port][/database] + RedisURL string + + // Logger is the logging implementation to use. + // Optional, but recommended for production use. + Logger logging.Logger +} + +// NewRedis creates a new Redis-backed counter implementation. +// +// Parameters: +// - config: Configuration options for the Redis counter +// +// Returns: +// - Counter: Redis implementation of the Counter interface +// - error: Any errors during initialization +func NewRedis(config RedisConfig) (Counter, error) { + err := assert.All( + assert.NotEmpty(config.RedisURL, "Redis URL must not be empty"), + ) + if err != nil { + return nil, err + } + + opts, err := redis.ParseURL(config.RedisURL) + if err != nil { + return nil, fmt.Errorf("failed to parse redis url: %w", err) + } + + rdb := redis.NewClient(opts) + config.Logger.Debug("pinging redis") + + // Test connection + _, err = rdb.Ping(context.Background()).Result() + if err != nil { + return nil, fmt.Errorf("failed to ping redis: %w", err) + } + + return &redisCounter{ + redis: rdb, + logger: config.Logger, + }, nil +} + +// Increment increases the counter by the given value and returns the new count. +// If ttl is provided and the counter is newly created (new value is equal to the increment value), +// it also sets an expiration time for the counter. +// +// Parameters: +// - ctx: Context for cancellation and tracing +// - key: Unique identifier for the counter +// - value: Amount to increment the counter by +// - ttl: Optional time-to-live duration. If provided and the key is new, sets this TTL. +// +// Returns: +// - int64: The new counter value after incrementing +// - error: Any errors that occurred during the operation +func (r *redisCounter) Increment(ctx context.Context, key string, value int64, ttl ...time.Duration) (int64, error) { + ctx, span := tracing.Start(ctx, "RedisCounter.Increment") + defer span.End() + + // Increment the counter + newValue, err := r.redis.IncrBy(ctx, key, value).Result() + if err != nil { + return 0, err + } + + // If TTL is provided and this is a new key (value == increment amount), + // set the expiration time + if len(ttl) > 0 && newValue == value { + if err := r.redis.Expire(ctx, key, ttl[0]).Err(); err != nil { + r.logger.Error("failed to set TTL on counter", "key", key, "error", err.Error()) + // We don't return the error since the increment operation was successful + } + } + + return newValue, nil +} + +// Get retrieves the current value of a counter. +// +// Parameters: +// - ctx: Context for cancellation and tracing +// - key: Unique identifier for the counter +// +// Returns: +// - int64: The current counter value +// - error: Any errors that occurred during the operation +func (r *redisCounter) Get(ctx context.Context, key string) (int64, error) { + ctx, span := tracing.Start(ctx, "RedisCounter.Get") + defer span.End() + + res, err := r.redis.Get(ctx, key).Result() + if err == redis.Nil { + // Key doesn't exist, return 0 without error + return 0, nil + } + if err != nil { + return 0, err + } + return strconv.ParseInt(res, 10, 64) +} + +// Close releases the Redis client connection. +// +// Returns: +// - error: Any errors that occurred during shutdown +func (r *redisCounter) Close() error { + return r.redis.Close() +} + +// MultiGet retrieves the values of multiple counters in a single operation. +// +// Parameters: +// - ctx: Context for cancellation and tracing +// - keys: Slice of unique identifiers for the counters +// +// Returns: +// - map[string]int64: Map of counter keys to their current values +// - error: Any errors that occurred during the operation +func (r *redisCounter) MultiGet(ctx context.Context, keys []string) (map[string]int64, error) { + ctx, span := tracing.Start(ctx, "RedisCounter.MultiGet") + defer span.End() + + if len(keys) == 0 { + return make(map[string]int64), nil + } + + values, err := r.redis.MGet(ctx, keys...).Result() + if err != nil { + return nil, err + } + + result := make(map[string]int64, len(keys)) + for i, key := range keys { + if i >= len(values) || values[i] == nil { + // Key doesn't exist, set to 0 + result[key] = 0 + continue + } + + s, ok := values[i].(string) + if !ok { + r.logger.Warn("unexpected type for counter value", + "key", key, + "type", fmt.Sprintf("%T", values[i]), + ) + continue + } + v, err := strconv.ParseInt(s, 10, 64) + if err != nil { + r.logger.Warn("failed to parse counter value", + "key", key, + "value", s, + "error", err, + ) + continue + } + + result[key] = v + } + + return result, nil +} diff --git a/go/pkg/counter/redis_test.go b/go/pkg/counter/redis_test.go new file mode 100644 index 0000000000..5aac6c5794 --- /dev/null +++ b/go/pkg/counter/redis_test.go @@ -0,0 +1,458 @@ +package counter + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/unkeyed/unkey/go/pkg/otel/logging" + "github.com/unkeyed/unkey/go/pkg/testutil/containers" +) + +func TestRedisCounter(t *testing.T) { + ctx := context.Background() + containers := containers.New(t) + _, redisURL, _ := containers.RunRedis() + + // Create a Redis counter + ctr, err := NewRedis(RedisConfig{ + RedisURL: redisURL, + Logger: logging.New(), + }) + require.NoError(t, err) + defer ctr.Close() + + // Test basic increment + t.Run("BasicIncrement", func(t *testing.T) { + key := "test:increment" + + // First increment should return 1 + val, err := ctr.Increment(ctx, key, 1) + require.NoError(t, err) + assert.Equal(t, int64(1), val) + + // Second increment should return 2 + val, err = ctr.Increment(ctx, key, 1) + require.NoError(t, err) + assert.Equal(t, int64(2), val) + + // Increment by 5 should return 7 + val, err = ctr.Increment(ctx, key, 5) + require.NoError(t, err) + assert.Equal(t, int64(7), val) + }) + + t.Run("IncrementWithTTL", func(t *testing.T) { + key := "test:increment:ttl" + ttl := 1 * time.Second + + // First increment with TTL + val, err := ctr.Increment(ctx, key, 1, ttl) + require.NoError(t, err) + assert.Equal(t, int64(1), val) + + // Get the value immediately + val, err = ctr.Get(ctx, key) + require.NoError(t, err) + assert.Equal(t, int64(1), val) + + // Wait for the key to expire + time.Sleep(2 * time.Second) + + // Key should be gone or zero + val, err = ctr.Get(ctx, key) + require.NoError(t, err) + assert.Equal(t, int64(0), val) + }) + + t.Run("Get", func(t *testing.T) { + key := "test:get" + + // Get non-existent key + val, err := ctr.Get(ctx, key) + require.NoError(t, err) + assert.Equal(t, int64(0), val) + + // Set a value and get it + _, err = ctr.Increment(ctx, key, 42) + require.NoError(t, err) + + val, err = ctr.Get(ctx, key) + require.NoError(t, err) + assert.Equal(t, int64(42), val) + }) + + // Table-driven tests for multiple increments + t.Run("TableDrivenIncrements", func(t *testing.T) { + tests := []struct { + name string + key string + increments []int64 + expected int64 + }{ + { + name: "Single increment", + key: "test:table:single", + increments: []int64{5}, + expected: 5, + }, + { + name: "Multiple increments", + key: "test:table:multiple", + increments: []int64{1, 2, 3, 4, 5}, + expected: 15, + }, + { + name: "Mixed positive and negative", + key: "test:table:mixed", + increments: []int64{10, -3, 5, -2}, + expected: 10, + }, + { + name: "Zero sum", + key: "test:table:zero", + increments: []int64{5, -5, 10, -10}, + expected: 0, + }, + { + name: "Large increments", + key: "test:table:large", + increments: []int64{1000, 2000, 3000}, + expected: 6000, + }, + } + + for _, tc := range tests { + tc := tc // Capture range variable for parallel execution + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var finalValue int64 + var err error + + for _, inc := range tc.increments { + finalValue, err = ctr.Increment(ctx, tc.key, inc) + require.NoError(t, err) + } + + assert.Equal(t, tc.expected, finalValue) + + // Verify with Get also + value, err := ctr.Get(ctx, tc.key) + require.NoError(t, err) + assert.Equal(t, tc.expected, value) + }) + } + }) + + // Test concurrent increments + t.Run("ConcurrentIncrements", func(t *testing.T) { + tests := []struct { + name string + key string + goroutines int + incrementsEach int + value int64 + expected int64 + }{ + { + name: "Few goroutines, many increments", + key: "test:concurrent:few", + goroutines: 5, + incrementsEach: 100, + value: 1, + expected: 500, // 5 * 100 * 1 + }, + { + name: "Many goroutines, few increments", + key: "test:concurrent:many", + goroutines: 50, + incrementsEach: 10, + value: 1, + expected: 500, // 50 * 10 * 1 + }, + { + name: "Medium scale mixed values", + key: "test:concurrent:mixed", + goroutines: 20, + incrementsEach: 20, + value: 5, + expected: 2000, // 20 * 20 * 5 + }, + { + name: "High contention with negative values", + key: "test:concurrent:negative", + goroutines: 30, + incrementsEach: 10, + value: -2, + expected: -600, // 30 * 10 * -2 + }, + } + + for _, tc := range tests { + tc := tc // Capture range variable for parallel execution + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var wg sync.WaitGroup + wg.Add(tc.goroutines) + + for i := 0; i < tc.goroutines; i++ { + go func() { + defer wg.Done() + for j := 0; j < tc.incrementsEach; j++ { + _, err := ctr.Increment(ctx, tc.key, tc.value) + if err != nil { + t.Errorf("increment error: %v", err) + return + } + } + }() + } + + wg.Wait() + + // Verify final value + value, err := ctr.Get(ctx, tc.key) + require.NoError(t, err) + assert.Equal(t, tc.expected, value, "Final counter value doesn't match expected") + }) + } + }) + + // Test interleaved operations (increment and get mixed together) + t.Run("InterleavedOperations", func(t *testing.T) { + key := "test:interleaved" + numWorkers := 10 + operationsPerWorker := 50 + + var wg sync.WaitGroup + wg.Add(numWorkers) + + // Launch goroutines that both increment and get values + for i := 0; i < numWorkers; i++ { + go func(id int) { + defer wg.Done() + + for j := 0; j < operationsPerWorker; j++ { + // Alternate between increment and get + if j%2 == 0 { + _, err := ctr.Increment(ctx, key, 1) + if err != nil { + t.Errorf("worker %d: increment error: %v", id, err) + return + } + } else { + _, err := ctr.Get(ctx, key) + if err != nil { + t.Errorf("worker %d: get error: %v", id, err) + return + } + } + } + }(i) + } + + wg.Wait() + + // Calculate expected value: each worker does operationsPerWorker/2 increments + expectedValue := int64(numWorkers * (operationsPerWorker / 2)) + + // Verify final value + value, err := ctr.Get(ctx, key) + require.NoError(t, err) + assert.Equal(t, expectedValue, value, "Final value after interleaved operations doesn't match expected") + }) + + // Test increments with TTL in parallel + t.Run("ConcurrentTTLIncrements", func(t *testing.T) { + key := "test:concurrent:ttl" + numWorkers := 10 + ttl := 3 * time.Second + + var wg sync.WaitGroup + wg.Add(numWorkers) + + for i := 0; i < numWorkers; i++ { + go func() { + defer wg.Done() + _, err := ctr.Increment(ctx, key, 1, ttl) + if err != nil { + t.Errorf("increment with TTL error: %v", err) + } + }() + } + + wg.Wait() + + // Verify value right after increments + value, err := ctr.Get(ctx, key) + require.NoError(t, err) + assert.Equal(t, int64(numWorkers), value) + + // Wait for TTL to expire + time.Sleep(4 * time.Second) + + // Key should be gone or zero + value, err = ctr.Get(ctx, key) + require.NoError(t, err) + assert.Equal(t, int64(0), value, "Counter should be zero after TTL expiry") + }) +} + +func TestRedisCounterConnection(t *testing.T) { + t.Run("InvalidURL", func(t *testing.T) { + // Test with invalid URL + _, err := NewRedis(RedisConfig{ + RedisURL: "invalid://url", + Logger: logging.New(), + }) + require.Error(t, err) + }) + + t.Run("ConnectionRefused", func(t *testing.T) { + // Test with non-existent Redis server + _, err := NewRedis(RedisConfig{ + RedisURL: "redis://localhost:12345", + Logger: logging.New(), + }) + require.Error(t, err) + }) + + t.Run("EmptyURL", func(t *testing.T) { + // Test with empty URL + _, err := NewRedis(RedisConfig{ + RedisURL: "", + Logger: logging.New(), + }) + require.Error(t, err) + }) +} + +func TestRedisCounterMultiGet(t *testing.T) { + ctx := context.Background() + containers := containers.New(t) + _, redisURL, _ := containers.RunRedis() + + // Create a Redis counter + ctr, err := NewRedis(RedisConfig{ + RedisURL: redisURL, + Logger: logging.New(), + }) + require.NoError(t, err) + defer ctr.Close() + + // Set up some test data + testData := map[string]int64{ + "multi:key1": 10, + "multi:key2": 20, + "multi:key3": 30, + "multi:key4": 40, + "multi:key5": 50, + } + + // Initialize counters + for key, value := range testData { + _, err := ctr.Increment(ctx, key, value) + require.NoError(t, err) + } + + t.Run("MultiGetAllExisting", func(t *testing.T) { + keys := []string{"multi:key1", "multi:key2", "multi:key3", "multi:key4", "multi:key5"} + values, err := ctr.MultiGet(ctx, keys) + require.NoError(t, err) + + // Verify all values match expected + for key, expectedValue := range testData { + value, exists := values[key] + assert.True(t, exists, "Key %s should exist in results", key) + assert.Equal(t, expectedValue, value, "Value for key %s should match", key) + } + }) + + t.Run("MultiGetMixedExistingAndNonExisting", func(t *testing.T) { + keys := []string{"multi:key1", "multi:nonexistent1", "multi:key3", "multi:nonexistent2"} + values, err := ctr.MultiGet(ctx, keys) + require.NoError(t, err) + + // Verify existing values + assert.Equal(t, int64(10), values["multi:key1"]) + assert.Equal(t, int64(30), values["multi:key3"]) + + // Verify non-existing values are 0 + assert.Equal(t, int64(0), values["multi:nonexistent1"]) + assert.Equal(t, int64(0), values["multi:nonexistent2"]) + }) + + t.Run("MultiGetEmpty", func(t *testing.T) { + values, err := ctr.MultiGet(ctx, []string{}) + require.NoError(t, err) + assert.Empty(t, values, "Result should be empty for empty keys list") + }) + + t.Run("MultiGetNonExisting", func(t *testing.T) { + keys := []string{"multi:nonexistent1", "multi:nonexistent2", "multi:nonexistent3"} + values, err := ctr.MultiGet(ctx, keys) + require.NoError(t, err) + + // All values should be 0 + for _, key := range keys { + assert.Equal(t, int64(0), values[key]) + } + }) + + t.Run("MultiGetLarge", func(t *testing.T) { + // Set up 100 counters + largeTestData := make(map[string]int64) + var largeKeys []string + + for i := 0; i < 100; i++ { + key := fmt.Sprintf("multi:large:%d", i) + largeTestData[key] = int64(i) + largeKeys = append(largeKeys, key) + _, err := ctr.Increment(ctx, key, int64(i)) + require.NoError(t, err) + } + + // Get all values + values, err := ctr.MultiGet(ctx, largeKeys) + require.NoError(t, err) + + // Verify all values match expected + assert.Equal(t, len(largeTestData), len(values)) + for key, expectedValue := range largeTestData { + assert.Equal(t, expectedValue, values[key]) + } + }) + + t.Run("ConcurrentMultiGet", func(t *testing.T) { + var wg sync.WaitGroup + numGoroutines := 10 + keys := []string{"multi:key1", "multi:key2", "multi:key3", "multi:key4", "multi:key5"} + + wg.Add(numGoroutines) + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + for j := 0; j < 10; j++ { + values, err := ctr.MultiGet(ctx, keys) + if err != nil { + t.Errorf("MultiGet error: %v", err) + return + } + + // Verify key counts + if len(values) != len(keys) { + t.Errorf("Expected %d values, got %d", len(keys), len(values)) + } + } + }() + } + + wg.Wait() + }) +} diff --git a/go/pkg/rpc/rpc.go b/go/pkg/prometheus/promhttp.go similarity index 75% rename from go/pkg/rpc/rpc.go rename to go/pkg/prometheus/promhttp.go index c92d870f95..5aeadbe40d 100644 --- a/go/pkg/rpc/rpc.go +++ b/go/pkg/prometheus/promhttp.go @@ -1,4 +1,4 @@ -package rpc +package prometheus import ( "context" @@ -7,12 +7,10 @@ import ( "sync" "time" - "connectrpc.com/connect" - "connectrpc.com/otelconnect" - "github.com/unkeyed/unkey/go/gen/proto/ratelimit/v1/ratelimitv1connect" + "github.com/prometheus/client_golang/prometheus/promhttp" + "github.com/unkeyed/unkey/go/pkg/fault" "github.com/unkeyed/unkey/go/pkg/otel/logging" - "github.com/unkeyed/unkey/go/pkg/otel/tracing" ) type Server struct { @@ -22,14 +20,10 @@ type Server struct { isListening bool srv *http.Server mux *http.ServeMux - // Define fields for the server } type Config struct { Logger logging.Logger - - RatelimitService ratelimitv1connect.RatelimitServiceHandler - // Define fields for the configuration } func New(config Config) (*Server, error) { @@ -53,31 +47,17 @@ func New(config Config) (*Server, error) { WriteTimeout: 20 * time.Second, } - interceptor, err := otelconnect.NewInterceptor( - otelconnect.WithTracerProvider(tracing.GetGlobalTraceProvider()), - otelconnect.WithTrustRemote(), - otelconnect.WithoutServerPeerAttributes(), - ) - if err != nil { - return nil, err - } - - if config.RatelimitService != nil { - mux.Handle( - ratelimitv1connect.NewRatelimitServiceHandler( - config.RatelimitService, - connect.WithInterceptors(interceptor), - ), - ) - } - - return &Server{ + s := &Server{ mu: sync.Mutex{}, logger: config.Logger, isListening: false, srv: srv, mux: mux, - }, nil + } + + mux.Handle("GET /metrics", promhttp.Handler()) + + return s, nil } // Listen starts the RPC server on the specified address. @@ -89,7 +69,7 @@ func New(config Config) (*Server, error) { // // Start server in a goroutine to allow for graceful shutdown // go func() { // if err := server.Listen(ctx, ":8080"); err != nil { -// log.Printf("rpc stopped: %v", err) +// log.Printf("server stopped: %v", err) // } // }() func (s *Server) Listen(ctx context.Context, addr string) error { @@ -104,7 +84,7 @@ func (s *Server) Listen(ctx context.Context, addr string) error { s.srv.Addr = addr - s.logger.Info("listening", "srv", "rpc", "addr", addr) + s.logger.Info("listening", "srv", "prometheus", "addr", addr) err := s.srv.ListenAndServe() if err != nil && !errors.Is(err, http.ErrServerClosed) { diff --git a/go/pkg/prometheus/servicediscovery.go b/go/pkg/prometheus/servicediscovery.go deleted file mode 100644 index 21d62e3bac..0000000000 --- a/go/pkg/prometheus/servicediscovery.go +++ /dev/null @@ -1,202 +0,0 @@ -package prometheus - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "net" - "net/http" - "sync" - "time" - - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promauto" - "github.com/prometheus/client_golang/prometheus/promhttp" - - "github.com/unkeyed/unkey/go/pkg/discovery" - "github.com/unkeyed/unkey/go/pkg/fault" - "github.com/unkeyed/unkey/go/pkg/otel/logging" -) - -// https://prometheus.io/docs/prometheus/latest/http_sd -// [ -// -// { -// "targets": [ "", ... ], -// "labels": { -// "": "", ... -// } -// }, -// ... -// -// ] -type ServiceDiscoveryResponseElement struct { - Targets []string `json:"targets"` - Labels map[string]string `json:"labels,omitempty"` -} - -type ServiceDiscoveryResponse = []ServiceDiscoveryResponseElement - -type Server struct { - mu sync.Mutex - - logger logging.Logger - isListening bool - srv *http.Server - mux *http.ServeMux - - sd discovery.Discoverer -} - -type Config struct { - Logger logging.Logger - Discovery discovery.Discoverer -} - -func New(config Config) (*Server, error) { - mux := http.NewServeMux() - srv := &http.Server{ - Handler: mux, - // See https://blog.cloudflare.com/the-complete-guide-to-golang-net-http-timeouts/ - // - // > # http.ListenAndServe is doing it wrong - // > Incidentally, this means that the package-level convenience functions that bypass http.Server - // > like http.ListenAndServe, http.ListenAndServeTLS and http.Serve are unfit for public Internet - // > Servers. - // > - // > Those functions leave the Timeouts to their default off value, with no way of enabling them, - // > so if you use them you'll soon be leaking connections and run out of file descriptors. I've - // > made this mistake at least half a dozen times. - // > - // > Instead, create a http.Server instance with ReadTimeout and WriteTimeout and use its - // > corresponding methods, like in the example a few paragraphs above. - ReadTimeout: 10 * time.Second, - WriteTimeout: 20 * time.Second, - } - - s := &Server{ - mu: sync.Mutex{}, - logger: config.Logger, - isListening: false, - srv: srv, - mux: mux, - sd: config.Discovery, - } - - mux.Handle("GET /metrics", promhttp.Handler()) - - // dummy metric for the demo - sdCounterDummyMetric := promauto.NewCounter(prometheus.CounterOpts{ - Name: "sd_called_total", - Help: "Demo counter just so we have something", - }) - - mux.HandleFunc("GET /sd", func(w http.ResponseWriter, r *http.Request) { - sdCounterDummyMetric.Add(1.0) - _, port, err := net.SplitHostPort(s.srv.Addr) - if err != nil { - s.internalServerError(err, w) - return - } - - addrs, err := s.sd.Discover() - if err != nil { - s.internalServerError(err, w) - return - } - - e := ServiceDiscoveryResponseElement{ - Targets: []string{}, - Labels: map[string]string{ - // I don't know why they're prefixed but that's what the docs do - "__meta_region": "todo", - "__meta_platform": "aws", - }, - } - - for _, addr := range addrs { - - e.Targets = append(e.Targets, fmt.Sprintf("%s:%s", addr, port)) - } - - w.Header().Add("Content-Type", "application/json") - b, err := json.Marshal(ServiceDiscoveryResponse{e}) - if err != nil { - s.internalServerError(err, w) - return - } - - _, err = w.Write(b) - if err != nil { - s.logger.Error("unable to write prometheus /sd response", - "err", err.Error(), - ) - } - }) - - return s, nil -} - -// Listen starts the RPC server on the specified address. -// This method blocks until the server shuts down or encounters an error. -// Once listening, the server will not start again if Listen is called multiple times. -// -// Example: -// -// // Start server in a goroutine to allow for graceful shutdown -// go func() { -// if err := server.Listen(ctx, ":8080"); err != nil { -// log.Printf("server stopped: %v", err) -// } -// }() -func (s *Server) Listen(ctx context.Context, addr string) error { - s.mu.Lock() - if s.isListening { - s.logger.Warn("already listening") - s.mu.Unlock() - return nil - } - s.isListening = true - s.mu.Unlock() - - s.srv.Addr = addr - - s.logger.Info("listening", "srv", "prometheus", "addr", addr) - - err := s.srv.ListenAndServe() - if err != nil && !errors.Is(err, http.ErrServerClosed) { - return fault.Wrap(err, fault.WithDesc("listening failed", "")) - } - return nil -} - -// Shutdown gracefully stops the RPC server, allowing in-flight requests -// to complete before returning. -// -// Example: -// -// // Handle shutdown signal -// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) -// defer cancel() -// if err := server.Shutdown(ctx); err != nil { -// log.Printf("server shutdown error: %v", err) -// } -func (s *Server) Shutdown() error { - s.mu.Lock() - defer s.mu.Unlock() - err := s.srv.Close() - if err != nil { - return fault.Wrap(err) - } - return nil -} - -func (s *Server) internalServerError(err error, w http.ResponseWriter) { - s.logger.Error(err.Error()) - w.WriteHeader(http.StatusInternalServerError) - _, wErr := w.Write([]byte(err.Error())) - if wErr != nil { - s.logger.Error("writing response failed", "err", wErr.Error()) - } -} diff --git a/go/pkg/testutil/containers/api.go b/go/pkg/testutil/containers/api.go index 0b0ba5e292..be44824e37 100644 --- a/go/pkg/testutil/containers/api.go +++ b/go/pkg/testutil/containers/api.go @@ -41,7 +41,7 @@ func (c *Containers) RunAPI(nodes int, mysqlDSN string) Cluster { require.NoError(c.t, err) c.t.Logf("building %s took %s", imageName, time.Since(t0)) - _, _, redisAddr := c.RunRedis() + _, _, redisUrl := c.RunRedis() cluster := Cluster{ Instances: []*dockertest.Resource{}, @@ -55,17 +55,15 @@ func (c *Containers) RunAPI(nodes int, mysqlDSN string) Cluster { Name: instanceId, Repository: imageName, Networks: []*dockertest.Network{c.network}, - ExposedPorts: []string{"7070", "9090", "9091"}, + ExposedPorts: []string{"7070"}, Cmd: []string{"api"}, Env: []string{ "UNKEY_HTTP_PORT=7070", - "UNKEY_CLUSTER=true", - "UNKEY_CLUSTER_GOSSIP_PORT=9090", - "UNKEY_CLUSTER_RPC_PORT=9091", "UNKEY_OTEL=true", "OTEL_EXPORTER_OTLP_ENDPOINT=http://otel:4318", "OTEL_EXPORTER_OTLP_PROTOCOL=http/protobuf", - fmt.Sprintf("UNKEY_CLUSTER_DISCOVERY_REDIS_URL=redis://%s", redisAddr), + "UNKEY_TEST_MODE=true", + fmt.Sprintf("UNKEY_REDIS_URL=%s", redisUrl), fmt.Sprintf("UNKEY_DATABASE_PRIMARY_DSN=%s", mysqlDSN), }, } diff --git a/go/pkg/testutil/containers/redis.go b/go/pkg/testutil/containers/redis.go index 09971d175c..91c16cc308 100644 --- a/go/pkg/testutil/containers/redis.go +++ b/go/pkg/testutil/containers/redis.go @@ -83,18 +83,16 @@ func (c *Containers) RunRedis() (client *redis.Client, hostAddr, dockerAddr stri require.NoError(c.t, c.pool.Purge(resource)) }) - hostAddr = fmt.Sprintf("localhost:%s", resource.GetPort("6379/tcp")) - dockerAddr = fmt.Sprintf("%s:6379", resource.GetIPInNetwork(c.network)) + hostAddr = fmt.Sprintf("redis://localhost:%s", resource.GetPort("6379/tcp")) + dockerAddr = fmt.Sprintf("redis://%s:6379", resource.GetIPInNetwork(c.network)) + + opts, err := redis.ParseURL(hostAddr) + require.NoError(c.t, err) // Configure the Redis client // nolint:exhaustruct - client = redis.NewClient(&redis.Options{ - Addr: hostAddr, - Password: "", // no password set - DB: 0, // use default DB - }) + client = redis.NewClient(opts) - // Wait for the Redis server to be ready ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() diff --git a/go/pkg/testutil/http.go b/go/pkg/testutil/http.go index 8674858c93..eb4c522805 100644 --- a/go/pkg/testutil/http.go +++ b/go/pkg/testutil/http.go @@ -15,7 +15,7 @@ import ( "github.com/unkeyed/unkey/go/internal/services/ratelimit" "github.com/unkeyed/unkey/go/pkg/clickhouse" "github.com/unkeyed/unkey/go/pkg/clock" - "github.com/unkeyed/unkey/go/pkg/cluster" + "github.com/unkeyed/unkey/go/pkg/counter" "github.com/unkeyed/unkey/go/pkg/db" "github.com/unkeyed/unkey/go/pkg/otel/logging" "github.com/unkeyed/unkey/go/pkg/testutil/containers" @@ -54,6 +54,8 @@ func NewHarness(t *testing.T) *Harness { dsn, _ := cont.RunMySQL() + _, redisUrl, _ := cont.RunRedis() + db, err := db.New(db.Config{ Logger: logger, PrimaryDSN: dsn, @@ -70,6 +72,9 @@ func NewHarness(t *testing.T) *Harness { srv, err := zen.New(zen.Config{ InstanceID: "test", Logger: logger, + Flags: &zen.Flags{ + TestMode: true, + }, }) require.NoError(t, err) @@ -94,10 +99,16 @@ func NewHarness(t *testing.T) *Harness { }) require.NoError(t, err) + ctr, err := counter.NewRedis(counter.RedisConfig{ + RedisURL: redisUrl, + Logger: logger, + }) + require.NoError(t, err) + ratelimitService, err := ratelimit.New(ratelimit.Config{ Logger: logger, - Cluster: cluster.NewNoop("test", "localhost"), Clock: clk, + Counter: ctr, }) require.NoError(t, err) diff --git a/go/pkg/zen/middleware_errors.go b/go/pkg/zen/middleware_errors.go index 22c57767de..e0c86ae025 100644 --- a/go/pkg/zen/middleware_errors.go +++ b/go/pkg/zen/middleware_errors.go @@ -34,11 +34,6 @@ func WithErrorHandling(logger logging.Logger) Middleware { return nil } - logger.Error("api error", - "error", err.Error(), - "publicMessage", fault.UserFacingMessage(err), - ) - // errorSteps := fault.Flatten(err) // if len(errorSteps) > 0 { @@ -171,6 +166,11 @@ func WithErrorHandling(logger logging.Logger) Middleware { break } + logger.Error("api error", + "error", err.Error(), + "requestId", s.RequestID(), + "publicMessage", fault.UserFacingMessage(err), + ) return s.JSON(http.StatusInternalServerError, openapi.InternalServerErrorResponse{ Meta: openapi.Meta{ RequestId: s.RequestID(), diff --git a/go/pkg/zen/server.go b/go/pkg/zen/server.go index 65221ac74e..7bf71bd674 100644 --- a/go/pkg/zen/server.go +++ b/go/pkg/zen/server.go @@ -25,10 +25,17 @@ type Server struct { isListening bool mux *http.ServeMux srv *http.Server + flags Flags sessions sync.Pool } +// Flags configures the behavior of a Server instance. +type Flags struct { + // TestMode enables test mode, accepting certain headers from untrusted clients such as fake times for testing purposes. + TestMode bool +} + // Config configures the behavior of a Server instance. type Config struct { // InstanceID uniquely identifies this server instance, useful for logging and tracing. @@ -36,6 +43,8 @@ type Config struct { // Logger provides structured logging for the server. If nil, logging is disabled. Logger logging.Logger + + Flags *Flags } // New creates a new server with the provided configuration. @@ -75,12 +84,19 @@ func New(config Config) (*Server, error) { WriteTimeout: 20 * time.Second, } + flags := Flags{ + TestMode: false, + } + if config.Flags != nil { + flags = *config.Flags + } s := &Server{ mu: sync.Mutex{}, logger: config.Logger, isListening: false, mux: mux, srv: srv, + flags: flags, sessions: sync.Pool{ New: func() any { return &Session{ @@ -122,6 +138,10 @@ func (s *Server) Mux() *http.ServeMux { return s.mux } +func (s *Server) Flags() Flags { + return s.flags +} + // Listen starts the HTTP server on the specified address. // This method blocks until the server shuts down or encounters an error. // Once listening, the server will not start again if Listen is called multiple times.