From 0c86605970ed2c4d7b39fcb54b085d08eb9f7b4a Mon Sep 17 00:00:00 2001
From: akshaydeo
Date: Mon, 20 Apr 2026 08:54:25 +0530
Subject: [PATCH] atomic updates to budgets and rate limits

---
 plugins/governance/store.go                  | 581 +++++++++++--------
 plugins/governance/store_concurrency_test.go | 124 ++++
 2 files changed, 456 insertions(+), 249 deletions(-)
 create mode 100644 plugins/governance/store_concurrency_test.go

diff --git a/plugins/governance/store.go b/plugins/governance/store.go
index 26f317d96b..127a21230f 100644
--- a/plugins/governance/store.go
+++ b/plugins/governance/store.go
@@ -99,13 +99,18 @@ type BudgetAndRateLimitStatus struct {
 type GovernanceStore interface {
 	GetGovernanceData(ctx context.Context) *GovernanceData
 	GetVirtualKey(ctx context.Context, vkValue string) (*configstoreTables.TableVirtualKey, bool)
-	// Budget crud
+	// Budget CRUD.
+	// UpsertBudgetConfig preserves in-memory CurrentUsage/LastReset on replacement —
+	// use it for every config publish (fresh load or admin edit) so a concurrent
+	// BumpBudgetUsage increment is never clobbered.
 	LoadBudget(ctx context.Context, budgetID string) *configstoreTables.TableBudget
-	StoreBudget(ctx context.Context, budgetID string, budget *configstoreTables.TableBudget)
+	UpsertBudgetConfig(ctx context.Context, budgetID string, config *configstoreTables.TableBudget)
 	DeleteBudget(ctx context.Context, budgetID string)
-	// Rate limit crud
+	// Rate limit CRUD. UpsertRateLimitConfig carries in-memory counter state
+	// (token + request CurrentUsage/LastReset) forward across replacements —
+	// same rationale as UpsertBudgetConfig.
 	LoadRateLimit(ctx context.Context, rateLimitID string) *configstoreTables.TableRateLimit
-	StoreRateLimit(ctx context.Context, rateLimitID string, rateLimit *configstoreTables.TableRateLimit)
+	UpsertRateLimitConfig(ctx context.Context, rateLimitID string, config *configstoreTables.TableRateLimit)
 	DeleteRateLimit(ctx context.Context, rateLimitID string)
 	// Provider-level governance checks
 	CheckProviderBudget(ctx context.Context, request *EvaluationRequest, baselines map[string]float64) (Decision, error)
@@ -221,9 +226,41 @@ func (gs *LocalGovernanceStore) LoadBudget(ctx context.Context, budgetID string)
 	return nil
 }
 
-// StoreBudget stores a budget in the local store.
-func (gs *LocalGovernanceStore) StoreBudget(ctx context.Context, budgetID string, budget *configstoreTables.TableBudget) {
-	gs.budgets.Store(budgetID, budget)
+// UpsertBudgetConfig publishes a budget config under budgetID, preserving the
+// in-memory CurrentUsage and LastReset from any prior snapshot so a concurrent
+// BumpBudgetUsage or ResetBudgetAt is never clobbered by a config replacement.
+// First writes (no prior entry) are handled via sync.Map.LoadOrStore so
+// simultaneous first-writers collapse to a single insertion and late arrivals
+// re-enter the CAS loop against the winner's snapshot.
+//
+// This method replaces the former blind StoreBudget: every caller installing
+// a budget (whether fresh load or config replacement) should funnel through
+// here rather than storing into the map directly.
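+//
+// A minimal usage sketch (the ID and duration are illustrative; any other
+// TableBudget fields a real config carries would be set the same way):
+//
+//	cfg := &configstoreTables.TableBudget{ID: "team-a", ResetDuration: "24h"}
+//	gs.UpsertBudgetConfig(ctx, cfg.ID, cfg) // prior CurrentUsage/LastReset survive the publish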
+func (gs *LocalGovernanceStore) UpsertBudgetConfig(ctx context.Context, budgetID string, config *configstoreTables.TableBudget) { + if config == nil { + return + } + for { + raw, exists := gs.budgets.Load(budgetID) + if !exists { + if _, loaded := gs.budgets.LoadOrStore(budgetID, config); !loaded { + return + } + continue + } + old, ok := raw.(*configstoreTables.TableBudget) + if !ok || old == nil { + gs.budgets.Store(budgetID, config) + return + } + merged := *config + merged.CurrentUsage = old.CurrentUsage + merged.LastReset = old.LastReset + if gs.budgets.CompareAndSwap(budgetID, raw, &merged) { + return + } + } } // DeleteBudget deletes a budget from the local store. @@ -241,9 +278,36 @@ func (gs *LocalGovernanceStore) LoadRateLimit(ctx context.Context, rateLimitID s return nil } -// StoreRateLimit stores a rate limit in the local store. -func (gs *LocalGovernanceStore) StoreRateLimit(ctx context.Context, rateLimitID string, rateLimit *configstoreTables.TableRateLimit) { - gs.rateLimits.Store(rateLimitID, rateLimit) +// UpsertRateLimitConfig publishes a rate-limit config under rateLimitID, +// preserving in-memory token and request counter state (TokenCurrentUsage / +// TokenLastReset / RequestCurrentUsage / RequestLastReset) from any prior +// snapshot. Same CAS-retry contract as UpsertBudgetConfig. +func (gs *LocalGovernanceStore) UpsertRateLimitConfig(ctx context.Context, rateLimitID string, config *configstoreTables.TableRateLimit) { + if config == nil { + return + } + for { + raw, exists := gs.rateLimits.Load(rateLimitID) + if !exists { + if _, loaded := gs.rateLimits.LoadOrStore(rateLimitID, config); !loaded { + return + } + continue + } + old, ok := raw.(*configstoreTables.TableRateLimit) + if !ok || old == nil { + gs.rateLimits.Store(rateLimitID, config) + return + } + merged := *config + merged.TokenCurrentUsage = old.TokenCurrentUsage + merged.TokenLastReset = old.TokenLastReset + merged.RequestCurrentUsage = old.RequestCurrentUsage + merged.RequestLastReset = old.RequestLastReset + if gs.rateLimits.CompareAndSwap(rateLimitID, raw, &merged) { + return + } + } } // DeleteRateLimit deletes a rate limit from the local store. @@ -251,6 +315,162 @@ func (gs *LocalGovernanceStore) DeleteRateLimit(ctx context.Context, rateLimitID gs.rateLimits.Delete(rateLimitID) } +// BumpBudgetUsage atomically increments CurrentUsage on the budget identified +// by budgetID and, as a side effect, zeros CurrentUsage / advances LastReset +// when the rolling ResetDuration has elapsed. Uses sync.Map.CompareAndSwap so +// concurrent callers on the same budget never drop increments — a lost CAS +// retries against the winner's snapshot. No-op when the budget is absent. +// +// This is the serialisation point for every usage increment: callers MUST +// funnel through this method (directly or via one of the higher-level +// Update*BudgetUsageInMemory wrappers) rather than doing a plain +// Load → clone → mutate → Store, which races. 
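+//
+// Illustrative lost-update interleaving that the CAS loop prevents (writers
+// A and B both bumping cost 1 from usage 10):
+//
+//	A: Load -> 10         B: Load -> 10
+//	A: clone, 10+1 = 11   B: clone, 10+1 = 11
+//	A: Store(&{11})       B: Store(&{11})   // final usage 11, one bump lost
+//
+// Under CompareAndSwap the second writer's swap fails because the stored
+// pointer changed underneath it, so it reloads the 11 snapshot and lands
+// on 12.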
+func (gs *LocalGovernanceStore) BumpBudgetUsage(ctx context.Context, budgetID string, cost float64) error { + for { + raw, exists := gs.budgets.Load(budgetID) + if !exists || raw == nil { + return nil + } + old, ok := raw.(*configstoreTables.TableBudget) + if !ok || old == nil { + return nil + } + clone := *old + now := time.Now() + if clone.ResetDuration != "" { + if duration, err := configstoreTables.ParseDuration(clone.ResetDuration); err == nil { + if now.Sub(clone.LastReset) >= duration { + clone.CurrentUsage = 0 + clone.LastReset = now + } + } + } + clone.CurrentUsage += cost + if gs.budgets.CompareAndSwap(budgetID, raw, &clone) { + return nil + } + } +} + +// BumpRateLimitUsage atomically increments the token and/or request counters on +// the rate limit identified by rateLimitID and, as a side effect, zeros the +// relevant counter / advances its LastReset when the rolling +// TokenResetDuration / RequestResetDuration has elapsed. Same CAS-retry +// contract as BumpBudgetUsage — no increment is ever dropped under +// concurrent callers. No-op when the rate limit is absent. +func (gs *LocalGovernanceStore) BumpRateLimitUsage(ctx context.Context, rateLimitID string, tokensUsed int64, shouldUpdateTokens, shouldUpdateRequests bool) error { + for { + raw, exists := gs.rateLimits.Load(rateLimitID) + if !exists || raw == nil { + return nil + } + old, ok := raw.(*configstoreTables.TableRateLimit) + if !ok || old == nil { + return nil + } + clone := *old + now := time.Now() + if clone.TokenResetDuration != nil { + if duration, err := configstoreTables.ParseDuration(*clone.TokenResetDuration); err == nil { + if now.Sub(clone.TokenLastReset) >= duration { + clone.TokenCurrentUsage = 0 + clone.TokenLastReset = now + } + } + } + if clone.RequestResetDuration != nil { + if duration, err := configstoreTables.ParseDuration(*clone.RequestResetDuration); err == nil { + if now.Sub(clone.RequestLastReset) >= duration { + clone.RequestCurrentUsage = 0 + clone.RequestLastReset = now + } + } + } + if shouldUpdateTokens { + clone.TokenCurrentUsage += tokensUsed + } + if shouldUpdateRequests { + clone.RequestCurrentUsage++ + } + if gs.rateLimits.CompareAndSwap(rateLimitID, raw, &clone) { + return nil + } + } +} + +// ResetBudgetAt atomically zeros the budget's CurrentUsage and advances its +// LastReset to newLastReset, provided the currently-stored budget has an +// older LastReset. Returns the reset budget and true when the CAS succeeds; +// (nil, false) if the budget is absent or another writer has already advanced +// LastReset to at least newLastReset. Callers (e.g. ResetExpiredBudgetsInMemory) +// use the false return to skip the DB-persistence and reference-refresh work +// that would otherwise be redundant. +func (gs *LocalGovernanceStore) ResetBudgetAt(ctx context.Context, budgetID string, newLastReset time.Time) (*configstoreTables.TableBudget, bool) { + for { + raw, exists := gs.budgets.Load(budgetID) + if !exists || raw == nil { + return nil, false + } + old, ok := raw.(*configstoreTables.TableBudget) + if !ok || old == nil { + return nil, false + } + if !old.LastReset.Before(newLastReset) { + // Someone else already advanced LastReset past ours, or the reset + // window hasn't actually opened relative to the stored snapshot. 
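+			// Example: resetter R1 wins the CAS with newLastReset=t1, then a
+			// bump lands fresh usage on the new window; resetter R2 arriving
+			// with the same t1 observes !LastReset.Before(t1) and bails here
+			// instead of re-zeroing that legitimate usage.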
+ return nil, false + } + clone := *old + clone.CurrentUsage = 0 + clone.LastReset = newLastReset + if gs.budgets.CompareAndSwap(budgetID, raw, &clone) { + return &clone, true + } + } +} + +// ResetRateLimitAt atomically resets one or both rate-limit counters on the +// rate limit identified by rateLimitID. A non-nil tokenNewLastReset resets the +// token counter and advances TokenLastReset; similarly for +// requestNewLastReset. Each reset is conditional on the corresponding +// LastReset currently being strictly older than the supplied target, so +// concurrent resetters collapse into a single successful write. Returns the +// updated snapshot and true when at least one counter was reset; (nil, false) +// otherwise. +func (gs *LocalGovernanceStore) ResetRateLimitAt(ctx context.Context, rateLimitID string, tokenNewLastReset, requestNewLastReset *time.Time) (*configstoreTables.TableRateLimit, bool) { + if tokenNewLastReset == nil && requestNewLastReset == nil { + return nil, false + } + for { + raw, exists := gs.rateLimits.Load(rateLimitID) + if !exists || raw == nil { + return nil, false + } + old, ok := raw.(*configstoreTables.TableRateLimit) + if !ok || old == nil { + return nil, false + } + clone := *old + didReset := false + if tokenNewLastReset != nil && old.TokenLastReset.Before(*tokenNewLastReset) { + clone.TokenCurrentUsage = 0 + clone.TokenLastReset = *tokenNewLastReset + didReset = true + } + if requestNewLastReset != nil && old.RequestLastReset.Before(*requestNewLastReset) { + clone.RequestCurrentUsage = 0 + clone.RequestLastReset = *requestNewLastReset + didReset = true + } + if !didReset { + return nil, false + } + if gs.rateLimits.CompareAndSwap(rateLimitID, raw, &clone) { + return &clone, true + } + } +} + // GetGovernanceData returns a snapshot of the current governance data. func (gs *LocalGovernanceStore) GetGovernanceData(ctx context.Context) *GovernanceData { refreshVKAssociations := func(vk *configstoreTables.TableVirtualKey) { @@ -992,37 +1212,11 @@ func (gs *LocalGovernanceStore) UpdateVirtualKeyBudgetUsageInMemory(ctx context. 
if vk == nil { return fmt.Errorf("virtual key cannot be nil") } - // Collect budget IDs using fast in-memory lookup instead of DB queries budgetIDs := gs.collectBudgetIDsFromMemory(ctx, vk, provider) - now := time.Now() for _, budgetID := range budgetIDs { - // Update in-memory cache for next read (lock-free) - if cachedBudgetValue, exists := gs.budgets.Load(budgetID); exists && cachedBudgetValue != nil { - if cachedBudget, ok := cachedBudgetValue.(*configstoreTables.TableBudget); ok && cachedBudget != nil { - // Clone FIRST to avoid race conditions - clone := *cachedBudget - oldUsage := clone.CurrentUsage - - // Check if budget needs reset (in-memory check) - operate on clone - if clone.ResetDuration != "" { - if duration, err := configstoreTables.ParseDuration(clone.ResetDuration); err == nil { - if now.Sub(clone.LastReset) >= duration { - clone.CurrentUsage = 0 - clone.LastReset = now - gs.logger.Debug("UpdateVirtualKeyBudgetUsageInMemory: Budget %s was reset (expired, duration: %v)", budgetID, duration) - } - } - } - - // Update the clone - clone.CurrentUsage += cost - gs.budgets.Store(budgetID, &clone) - gs.logger.Debug("UpdateVirtualKeyBudgetUsageInMemory: Updated budget %s: %.4f -> %.4f (added %.4f)", - budgetID, oldUsage, clone.CurrentUsage, cost) - } - } else { - gs.logger.Warn("UpdateVirtualKeyBudgetUsageInMemory: Budget %s not found in local store", budgetID) + if err := gs.BumpBudgetUsage(ctx, budgetID, cost); err != nil { + return err } } return nil @@ -1030,36 +1224,14 @@ func (gs *LocalGovernanceStore) UpdateVirtualKeyBudgetUsageInMemory(ctx context. // UpdateProviderAndModelBudgetUsageInMemory performs atomic budget updates for both provider-level and model-level configs (in memory) func (gs *LocalGovernanceStore) UpdateProviderAndModelBudgetUsageInMemory(ctx context.Context, model string, provider schemas.ModelProvider, cost float64) error { - now := time.Now() - - // Helper function to update a budget by ID - updateBudget := func(budgetID string) { - if cachedBudgetValue, exists := gs.budgets.Load(budgetID); exists && cachedBudgetValue != nil { - if cachedBudget, ok := cachedBudgetValue.(*configstoreTables.TableBudget); ok && cachedBudget != nil { - // Clone FIRST to avoid race conditions - clone := *cachedBudget - // Check if budget needs reset (in-memory check) - operate on clone - if clone.ResetDuration != "" { - if duration, err := configstoreTables.ParseDuration(clone.ResetDuration); err == nil { - if now.Sub(clone.LastReset) >= duration { - clone.CurrentUsage = 0 - clone.LastReset = now - } - } - } - // Update the clone - clone.CurrentUsage += cost - gs.budgets.Store(budgetID, &clone) - } - } - } - // 1. 
Update provider-level budget (if provider is set) if provider != "" { providerKey := string(provider) if value, exists := gs.providers.Load(providerKey); exists && value != nil { if providerTable, ok := value.(*configstoreTables.TableProvider); ok && providerTable != nil && providerTable.BudgetID != nil { - updateBudget(*providerTable.BudgetID) + if err := gs.BumpBudgetUsage(ctx, *providerTable.BudgetID, cost); err != nil { + return err + } } } } @@ -1070,7 +1242,9 @@ func (gs *LocalGovernanceStore) UpdateProviderAndModelBudgetUsageInMemory(ctx co key := fmt.Sprintf("%s:%s", model, string(provider)) if value, exists := gs.modelConfigs.Load(key); exists && value != nil { if mc, ok := value.(*configstoreTables.TableModelConfig); ok && mc != nil && mc.BudgetID != nil { - updateBudget(*mc.BudgetID) + if err := gs.BumpBudgetUsage(ctx, *mc.BudgetID, cost); err != nil { + return err + } } } } @@ -1078,7 +1252,9 @@ func (gs *LocalGovernanceStore) UpdateProviderAndModelBudgetUsageInMemory(ctx co // Always check model-only config (if exists) - regardless of whether model+provider config exists // Uses findModelOnlyConfig for cross-provider model name normalization if mc, _ := gs.findModelOnlyConfig(ctx, model); mc != nil && mc.BudgetID != nil { - updateBudget(*mc.BudgetID) + if err := gs.BumpBudgetUsage(ctx, *mc.BudgetID, cost); err != nil { + return err + } } return nil @@ -1090,51 +1266,16 @@ func (gs *LocalGovernanceStore) UpdateUserBudgetUsageInMemory(ctx context.Contex return nil } -// UpdateProviderAndModelRateLimitUsageInMemory updates rate limit counters for both provider-level and model-level rate limits (lock-free) +// UpdateProviderAndModelRateLimitUsageInMemory updates rate limit counters for both provider-level and model-level rate limits. func (gs *LocalGovernanceStore) UpdateProviderAndModelRateLimitUsageInMemory(ctx context.Context, model string, provider schemas.ModelProvider, tokensUsed int64, shouldUpdateTokens bool, shouldUpdateRequests bool) error { - now := time.Now() - - // Helper function to update a rate limit by ID - updateRateLimit := func(rateLimitID string) { - if cachedRateLimitValue, exists := gs.rateLimits.Load(rateLimitID); exists && cachedRateLimitValue != nil { - if cachedRateLimit, ok := cachedRateLimitValue.(*configstoreTables.TableRateLimit); ok && cachedRateLimit != nil { - // Clone FIRST to avoid race conditions - clone := *cachedRateLimit - // Check if rate limit needs reset (in-memory check) - operate on clone - if clone.TokenResetDuration != nil { - if duration, err := configstoreTables.ParseDuration(*clone.TokenResetDuration); err == nil { - if now.Sub(clone.TokenLastReset) >= duration { - clone.TokenCurrentUsage = 0 - clone.TokenLastReset = now - } - } - } - if clone.RequestResetDuration != nil { - if duration, err := configstoreTables.ParseDuration(*clone.RequestResetDuration); err == nil { - if now.Sub(clone.RequestLastReset) >= duration { - clone.RequestCurrentUsage = 0 - clone.RequestLastReset = now - } - } - } - // Update the clone - if shouldUpdateTokens { - clone.TokenCurrentUsage += tokensUsed - } - if shouldUpdateRequests { - clone.RequestCurrentUsage += 1 - } - gs.rateLimits.Store(rateLimitID, &clone) - } - } - } - // 1. 
Update provider-level rate limit (if provider is set) if provider != "" { providerKey := string(provider) if value, exists := gs.providers.Load(providerKey); exists && value != nil { if providerTable, ok := value.(*configstoreTables.TableProvider); ok && providerTable != nil && providerTable.RateLimitID != nil { - updateRateLimit(*providerTable.RateLimitID) + if err := gs.BumpRateLimitUsage(ctx, *providerTable.RateLimitID, tokensUsed, shouldUpdateTokens, shouldUpdateRequests); err != nil { + return err + } } } } @@ -1145,7 +1286,9 @@ func (gs *LocalGovernanceStore) UpdateProviderAndModelRateLimitUsageInMemory(ctx key := fmt.Sprintf("%s:%s", model, string(provider)) if value, exists := gs.modelConfigs.Load(key); exists && value != nil { if mc, ok := value.(*configstoreTables.TableModelConfig); ok && mc != nil && mc.RateLimitID != nil { - updateRateLimit(*mc.RateLimitID) + if err := gs.BumpRateLimitUsage(ctx, *mc.RateLimitID, tokensUsed, shouldUpdateTokens, shouldUpdateRequests); err != nil { + return err + } } } } @@ -1153,55 +1296,24 @@ func (gs *LocalGovernanceStore) UpdateProviderAndModelRateLimitUsageInMemory(ctx // Always check model-only config (if exists) - regardless of whether model+provider config exists // Uses findModelOnlyConfig for cross-provider model name normalization if mc, _ := gs.findModelOnlyConfig(ctx, model); mc != nil && mc.RateLimitID != nil { - updateRateLimit(*mc.RateLimitID) + if err := gs.BumpRateLimitUsage(ctx, *mc.RateLimitID, tokensUsed, shouldUpdateTokens, shouldUpdateRequests); err != nil { + return err + } } return nil } -// UpdateVirtualKeyRateLimitUsageInMemory updates rate limit counters for VK-level rate limits (lock-free) +// UpdateVirtualKeyRateLimitUsageInMemory updates rate limit counters for VK-level rate limits. 
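+// Each collected rate-limit ID is funneled through BumpRateLimitUsage, so the
+// per-ID update inherits its CAS no-lost-increment guarantee. Hypothetical
+// call shape after a completed request (variable names illustrative):
+//
+//	// bump both the token and the request counters for this VK's limits
+//	err := gs.UpdateVirtualKeyRateLimitUsageInMemory(ctx, vk, provider, tokensUsed, true, true)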
func (gs *LocalGovernanceStore) UpdateVirtualKeyRateLimitUsageInMemory(ctx context.Context, vk *configstoreTables.TableVirtualKey, provider schemas.ModelProvider, tokensUsed int64, shouldUpdateTokens bool, shouldUpdateRequests bool) error { if vk == nil { return fmt.Errorf("virtual key cannot be nil") } // Collect rate limit IDs using fast in-memory lookup instead of DB queries rateLimitIDs := gs.collectRateLimitIDsFromMemory(ctx, vk, provider) - now := time.Now() for _, rateLimitID := range rateLimitIDs { - // Update in-memory cache for next read (lock-free) - if cachedRateLimitValue, exists := gs.rateLimits.Load(rateLimitID); exists && cachedRateLimitValue != nil { - if cachedRateLimit, ok := cachedRateLimitValue.(*configstoreTables.TableRateLimit); ok && cachedRateLimit != nil { - // Clone FIRST to avoid race conditions - clone := *cachedRateLimit - - // Check if rate limit needs reset (in-memory check) - operate on clone - if clone.TokenResetDuration != nil { - if duration, err := configstoreTables.ParseDuration(*clone.TokenResetDuration); err == nil { - if now.Sub(clone.TokenLastReset) >= duration { - clone.TokenCurrentUsage = 0 - clone.TokenLastReset = now - gs.logger.Debug("UpdateRateLimitUsage: Rate limit %s was reset (expired, duration: %v)", rateLimitID, duration) - } - } - } - if clone.RequestResetDuration != nil { - if duration, err := configstoreTables.ParseDuration(*clone.RequestResetDuration); err == nil { - if now.Sub(clone.RequestLastReset) >= duration { - clone.RequestCurrentUsage = 0 - clone.RequestLastReset = now - gs.logger.Debug("UpdateRateLimitUsage: Rate limit %s was reset (expired, duration: %v)", rateLimitID, duration) - } - } - } - // Update the clone - if shouldUpdateTokens { - clone.TokenCurrentUsage += tokensUsed - } - if shouldUpdateRequests { - clone.RequestCurrentUsage += 1 - } - gs.rateLimits.Store(rateLimitID, &clone) - } + if err := gs.BumpRateLimitUsage(ctx, rateLimitID, tokensUsed, shouldUpdateTokens, shouldUpdateRequests); err != nil { + return err } } return nil @@ -1213,151 +1325,122 @@ func (gs *LocalGovernanceStore) UpdateUserRateLimitUsageInMemory(ctx context.Con return nil } -// ResetExpiredBudgetsInMemory checks and resets budgets that have exceeded their reset duration (lock-free) +// ResetExpiredBudgetsInMemory checks and resets budgets that have exceeded their reset duration. +// Decision of whether to reset is computed per-budget from the snapshot observed via Range; the +// actual CAS is delegated to ResetBudgetAt, which skips already-reset snapshots and never drops +// a concurrent usage increment. func (gs *LocalGovernanceStore) ResetExpiredBudgetsInMemory(ctx context.Context) []*configstoreTables.TableBudget { now := time.Now() var resetBudgets []*configstoreTables.TableBudget - // We reset all budgets gs.budgets.Range(func(key, value any) bool { - // Type-safe conversion budget, ok := value.(*configstoreTables.TableBudget) if !ok || budget == nil { - return true // continue + return true } - // Determine whether the budget needs resetting var shouldReset bool var newLastReset time.Time - // Any budget and rate limit can be calendar aligned if budget.CalendarAligned { - // Calendar-aligned: reset when we've entered a genuinely new calendar period. 
currentPeriodStart := configstoreTables.GetCalendarPeriodStart(budget.ResetDuration, now) if currentPeriodStart.After(budget.LastReset) { shouldReset = true newLastReset = currentPeriodStart } } else { - // Rolling duration: reset after the configured duration has elapsed duration, err := configstoreTables.ParseDuration(budget.ResetDuration) if err != nil { gs.logger.Error("invalid budget reset duration %s: %v", budget.ResetDuration, err) - return true // continue + return true } if now.Sub(budget.LastReset) >= duration { shouldReset = true newLastReset = now } } - if shouldReset { - // Create a copy to avoid data race (sync.Map is concurrent-safe for reads/writes but not mutations) - copiedBudget := *budget - oldUsage := copiedBudget.CurrentUsage - copiedBudget.CurrentUsage = 0 - copiedBudget.LastReset = newLastReset - gs.LastDBUsagesBudgetsMu.Lock() - gs.LastDBUsagesBudgets[copiedBudget.ID] = 0 - gs.LastDBUsagesBudgetsMu.Unlock() - // Atomically replace the entry using the original key - gs.budgets.Store(key, &copiedBudget) - resetBudgets = append(resetBudgets, &copiedBudget) - // Update all VKs, teams, customers, and provider configs that reference this budget - gs.updateBudgetReferences(ctx, &copiedBudget) - gs.logger.Debug(fmt.Sprintf("Reset budget %s (was %.2f, reset to 0)", - copiedBudget.ID, oldUsage)) + if !shouldReset { + return true } - return true // continue + resetBudget, ok := gs.ResetBudgetAt(ctx, budget.ID, newLastReset) + if !ok { + // Another resetter got there first, or a concurrent usage update + // already advanced LastReset past ours; nothing to do. + return true + } + oldUsage := budget.CurrentUsage + gs.LastDBUsagesBudgetsMu.Lock() + gs.LastDBUsagesBudgets[resetBudget.ID] = 0 + gs.LastDBUsagesBudgetsMu.Unlock() + resetBudgets = append(resetBudgets, resetBudget) + gs.updateBudgetReferences(ctx, resetBudget) + gs.logger.Debug(fmt.Sprintf("Reset budget %s (was %.2f, reset to 0)", + resetBudget.ID, oldUsage)) + return true }) return resetBudgets } -// ResetExpiredRateLimitsInMemory performs background reset of expired rate limits for both provider-level and VK-level (lock-free) +// ResetExpiredRateLimitsInMemory performs background reset of expired rate limits for both provider-level and VK-level. +// Decision of whether each counter needs resetting is computed per-rate-limit from the snapshot observed via Range; +// the actual CAS is delegated to ResetRateLimitAt, which skips already-reset snapshots and never drops a concurrent +// increment. func (gs *LocalGovernanceStore) ResetExpiredRateLimitsInMemory(ctx context.Context) []*configstoreTables.TableRateLimit { now := time.Now() var resetRateLimits []*configstoreTables.TableRateLimit + // resolvePeriodStart returns the next LastReset target for a counter whose + // reset-duration setting is resetDuration and whose current LastReset is + // lastReset. Returns nil when no reset is due (or the duration is invalid). 
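+	// Illustrative outcomes: a rolling "1h" counter whose lastReset is two
+	// hours old maps to &now; a calendar-aligned counter still inside its
+	// current period maps to nil, because GetCalendarPeriodStart yields a
+	// period start that is not After(lastReset).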
+ resolvePeriodStart := func(resetDuration *string, calendarAligned bool, lastReset time.Time) *time.Time { + if resetDuration == nil { + return nil + } + if calendarAligned { + period := configstoreTables.GetCalendarPeriodStart(*resetDuration, now) + if period.After(lastReset) { + return &period + } + return nil + } + duration, err := configstoreTables.ParseDuration(*resetDuration) + if err != nil { + gs.logger.Error("invalid rate limit reset duration %s: %v", *resetDuration, err) + return nil + } + if now.Sub(lastReset) >= duration { + t := now + return &t + } + return nil + } gs.rateLimits.Range(func(key, value any) bool { - // Type-safe conversion rateLimit, ok := value.(*configstoreTables.TableRateLimit) if !ok || rateLimit == nil { - return true // continue + return true } - tokenNeedsReset := false - requestNeedsReset := false - // Any budget and rate limit can be calendar aligned - if rateLimit.CalendarAligned { - // Calendar-aligned: reset when we've entered a genuinely new calendar period. - if rateLimit.TokenResetDuration != nil { - currentPeriodStart := configstoreTables.GetCalendarPeriodStart(*rateLimit.TokenResetDuration, now) - if currentPeriodStart.After(rateLimit.TokenLastReset) { - tokenNeedsReset = true - } - } - if rateLimit.RequestResetDuration != nil { - currentPeriodStart := configstoreTables.GetCalendarPeriodStart(*rateLimit.RequestResetDuration, now) - if currentPeriodStart.After(rateLimit.RequestLastReset) { - requestNeedsReset = true - } - } - } else { - // Rolling duration: reset after the configured duration has elapsed - if rateLimit.TokenResetDuration != nil { - duration, err := configstoreTables.ParseDuration(*rateLimit.TokenResetDuration) - if err != nil { - gs.logger.Error("invalid budget reset duration %s: %v", *rateLimit.TokenResetDuration, err) - return true // continue - } - if now.Sub(rateLimit.TokenLastReset) >= duration { - tokenNeedsReset = true - } - } - if rateLimit.RequestResetDuration != nil { - duration, err := configstoreTables.ParseDuration(*rateLimit.RequestResetDuration) - if err != nil { - gs.logger.Error("invalid budget reset duration %s: %v", *rateLimit.RequestResetDuration, err) - return true // continue - } - if now.Sub(rateLimit.RequestLastReset) >= duration { - requestNeedsReset = true - } - } - } - // Create a copy to avoid data race (sync.Map is concurrent-safe for reads/writes but not mutations) - copiedRateLimit := *rateLimit - // Reset token limits if expired - if tokenNeedsReset && copiedRateLimit.TokenResetDuration != nil { - if duration, err := configstoreTables.ParseDuration(*copiedRateLimit.TokenResetDuration); err == nil { - if now.Sub(copiedRateLimit.TokenLastReset) >= duration { - copiedRateLimit.TokenCurrentUsage = 0 - copiedRateLimit.TokenLastReset = now - gs.LastDBUsagesRateLimitsTokensMu.Lock() - gs.LastDBUsagesTokensRateLimits[copiedRateLimit.ID] = 0 - gs.LastDBUsagesRateLimitsTokensMu.Unlock() - } - } - } - // Reset request limits if expired - if requestNeedsReset && copiedRateLimit.RequestResetDuration != nil { - if duration, err := configstoreTables.ParseDuration(*copiedRateLimit.RequestResetDuration); err == nil { - if now.Sub(copiedRateLimit.RequestLastReset) >= duration { - copiedRateLimit.RequestCurrentUsage = 0 - copiedRateLimit.RequestLastReset = now - gs.LastDBUsagesRateLimitsRequestsMu.Lock() - gs.LastDBUsagesRequestsRateLimits[copiedRateLimit.ID] = 0 - gs.LastDBUsagesRateLimitsRequestsMu.Unlock() - } - } - } - // Only commit the copy + emit for DB reset + rescan references when something - // actually 
expired. Without this guard the 10-second tick would always call - // gs.updateRateLimitReferences (which scans every VK + provider-config) and - // return every rate limit to the caller for a redundant DB update. Mirrors - // the `if shouldReset { ... }` guard in ResetExpiredBudgetsInMemory above. - if tokenNeedsReset || requestNeedsReset { - // Atomically replace the entry using the original key - gs.rateLimits.Store(key, &copiedRateLimit) - resetRateLimits = append(resetRateLimits, &copiedRateLimit) - // Update all VKs and provider configs that reference this rate limit - gs.updateRateLimitReferences(ctx, &copiedRateLimit) + tokenNewLastReset := resolvePeriodStart(rateLimit.TokenResetDuration, rateLimit.CalendarAligned, rateLimit.TokenLastReset) + requestNewLastReset := resolvePeriodStart(rateLimit.RequestResetDuration, rateLimit.CalendarAligned, rateLimit.RequestLastReset) + if tokenNewLastReset == nil && requestNewLastReset == nil { + return true } - return true // continue + resetRateLimit, ok := gs.ResetRateLimitAt(ctx, rateLimit.ID, tokenNewLastReset, requestNewLastReset) + if !ok { + return true + } + // Clear DB-baseline markers only for the counters we actually reset in + // this call. Baseline locks stay independent of the primary sync.Map + // CAS — they guard a separate map whose values just need consistency, + // not atomicity with the counter mutation. + if tokenNewLastReset != nil { + gs.LastDBUsagesRateLimitsTokensMu.Lock() + gs.LastDBUsagesTokensRateLimits[resetRateLimit.ID] = 0 + gs.LastDBUsagesRateLimitsTokensMu.Unlock() + } + if requestNewLastReset != nil { + gs.LastDBUsagesRateLimitsRequestsMu.Lock() + gs.LastDBUsagesRequestsRateLimits[resetRateLimit.ID] = 0 + gs.LastDBUsagesRateLimitsRequestsMu.Unlock() + } + resetRateLimits = append(resetRateLimits, resetRateLimit) + gs.updateRateLimitReferences(ctx, resetRateLimit) + return true }) return resetRateLimits } diff --git a/plugins/governance/store_concurrency_test.go b/plugins/governance/store_concurrency_test.go new file mode 100644 index 0000000000..67e1f081ff --- /dev/null +++ b/plugins/governance/store_concurrency_test.go @@ -0,0 +1,124 @@ +package governance + +import ( + "context" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// newStandaloneStore builds a LocalGovernanceStore with no config store / +// persistence — just the in-memory maps. Enough for exercising the CAS +// primitives without going through GovernanceConfig preload paths. +func newStandaloneStore(t *testing.T) *LocalGovernanceStore { + t.Helper() + return &LocalGovernanceStore{ + logger: NewMockLogger(), + LastDBUsagesBudgets: map[string]float64{}, + LastDBUsagesTokensRateLimits: map[string]int64{}, + LastDBUsagesRequestsRateLimits: map[string]int64{}, + } +} + +// TestBumpBudgetUsage_NoLostIncrements proves the CAS retry loop in +// BumpBudgetUsage never drops a concurrent increment. Without the CAS, the +// Load→clone→mutate→Store sequence races and the final CurrentUsage ends up +// strictly less than N*cost under contention. 
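+//
+// With 256 goroutines x 50 increments of cost 1.0 the expected total is
+// 12800; float64 sums of whole numbers this small are exact, so the equality
+// assertion below is deterministic and a single lost write fails it.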
+func TestBumpBudgetUsage_NoLostIncrements(t *testing.T) { + store := newStandaloneStore(t) + budgetID := "concurrent-budget" + store.budgets.Store(budgetID, buildBudget(budgetID, 1_000_000_000, "24h")) + + const goroutines = 256 + const perGoroutine = 50 + const cost = 1.0 + expected := float64(goroutines * perGoroutine) + + var wg sync.WaitGroup + wg.Add(goroutines) + for i := 0; i < goroutines; i++ { + go func() { + defer wg.Done() + for j := 0; j < perGoroutine; j++ { + assert.NoError(t, store.BumpBudgetUsage(context.Background(), budgetID, cost)) + } + }() + } + wg.Wait() + + final := store.LoadBudget(context.Background(), budgetID) + require.NotNil(t, final) + assert.Equal(t, expected, final.CurrentUsage, "CurrentUsage must equal total increments — any shortfall is a dropped write") +} + +// TestBumpRateLimitUsage_NoLostIncrements covers the rate-limit variant of +// the same race: token and request counters are independent int64 fields +// updated on the same struct, and both must survive contention intact. +func TestBumpRateLimitUsage_NoLostIncrements(t *testing.T) { + store := newStandaloneStore(t) + rlID := "concurrent-rate-limit" + store.rateLimits.Store(rlID, buildRateLimit(rlID, 1_000_000_000, 1_000_000_000)) + + const goroutines = 256 + const perGoroutine = 50 + const tokensPerCall = int64(7) + + var wg sync.WaitGroup + wg.Add(goroutines) + for i := 0; i < goroutines; i++ { + go func() { + defer wg.Done() + for j := 0; j < perGoroutine; j++ { + assert.NoError(t, store.BumpRateLimitUsage(context.Background(), rlID, tokensPerCall, true, true)) + } + }() + } + wg.Wait() + + final := store.LoadRateLimit(context.Background(), rlID) + require.NotNil(t, final) + assert.Equal(t, int64(goroutines*perGoroutine)*tokensPerCall, final.TokenCurrentUsage, "TokenCurrentUsage dropped increments") + assert.Equal(t, int64(goroutines*perGoroutine), final.RequestCurrentUsage, "RequestCurrentUsage dropped increments") +} + +// TestResetBudgetAt_ConcurrentResettersCollapse confirms that many goroutines +// all trying to reset the same budget to the same newLastReset deduplicate +// cleanly via CAS — exactly one resetter observes the transition, everyone +// else gets (nil, false). Without the re-check inside ResetBudgetAt, each +// goroutine would re-zero the counter and drop any increments applied in +// between. +func TestResetBudgetAt_ConcurrentResettersCollapse(t *testing.T) { + store := newStandaloneStore(t) + budgetID := "reset-collapse" + old := buildBudget(budgetID, 1000, "1h") + old.LastReset = time.Now().Add(-2 * time.Hour) + old.CurrentUsage = 999 + store.budgets.Store(budgetID, old) + + const goroutines = 128 + newLastReset := time.Now() + + var successes atomic.Int64 + var wg sync.WaitGroup + wg.Add(goroutines) + for i := 0; i < goroutines; i++ { + go func() { + defer wg.Done() + if _, ok := store.ResetBudgetAt(context.Background(), budgetID, newLastReset); ok { + successes.Add(1) + } + }() + } + wg.Wait() + + assert.Equal(t, int64(1), successes.Load(), "exactly one resetter should win the CAS when all target the same newLastReset") + final := store.LoadBudget(context.Background(), budgetID) + require.NotNil(t, final) + assert.Equal(t, 0.0, final.CurrentUsage) + assert.True(t, final.LastReset.Equal(newLastReset)) +} +
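+
+// Note: these tests are only meaningful under the race detector; a plain
+// `go test` pass does not rule out lost updates on other schedules. Run:
+//
+//	go test -race -run 'NoLostIncrements|ResettersCollapse' ./plugins/governance/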