From 8e2fab780ea8b28d1f8710bb9c086b35130385e4 Mon Sep 17 00:00:00 2001
From: Pratham-Mishra04
Date: Wed, 15 Apr 2026 12:58:32 +0530
Subject: [PATCH] fix: "send on closed channel" race condition on provider reload

---
 core/bifrost.go                               | 445 +++++---
 core/bifrost_test.go                          | 997 ++++++++++++++++++
 transports/bifrost-http/handlers/providers.go |   2 -
 3 files changed, 1312 insertions(+), 132 deletions(-)

diff --git a/core/bifrost.go b/core/bifrost.go
index 7db6663758..17543d4cbe 100644
--- a/core/bifrost.go
+++ b/core/bifrost.go
@@ -89,12 +89,41 @@ type Bifrost struct {
 // ProviderQueue wraps a provider's request channel with lifecycle management
 // to prevent "send on closed channel" panics during provider removal/update.
 // Producers must check the closing flag or select on the done channel before sending.
+//
+// Why pq.queue is NEVER closed:
+//
+// Closing a channel in Go causes any concurrent send to that channel to panic
+// ("send on closed channel"). There is always a TOCTOU window between a
+// producer's isClosing() check and its select { case pq.queue <- msg: ... }:
+// the producer could pass isClosing() while the queue is open, get preempted,
+// and resume only after the queue is closed. Go's selectgo evaluates select
+// cases in a random order, so even having case <-pq.done: in the same select
+// does not protect against this — if selectgo evaluates the send case first on
+// a closed channel it panics immediately via goto sclose, before reaching done.
+//
+// To close pq.queue safely you would need a sender-side WaitGroup so that
+// signalClosing could wait for every in-flight producer to finish. That adds
+// non-trivial overhead on the hot request path.
+//
+// Instead, pq.done is the sole shutdown signal. Receiving from a closed channel
+// is always safe (returns the zero value immediately), so:
+// - Workers exit via case <-pq.done: — safe
+// - Producers bail via case <-pq.done: — safe
+// - drainQueueWithErrors handles any messages that slip through the TOCTOU window
+//
+// pq.queue is garbage collected automatically:
+// - RemoveProvider calls requestQueues.Delete, dropping the map's reference.
+// - UpdateProvider calls requestQueues.Store with a new queue, dropping the
+//   map's reference to oldPq. Shutdown does not Delete at all — the whole
+//   Bifrost instance is torn down.
+// In all cases, once no producer goroutine holds a reference to the
+// ProviderQueue, both the struct and pq.queue are eligible for GC.
+// No explicit close is needed.
 type ProviderQueue struct {
-	queue      chan *ChannelMessage // the actual request queue channel
-	done       chan struct{}        // closed to signal shutdown to producers
+	queue      chan *ChannelMessage // the actual request queue channel — never closed, see above
+	done       chan struct{}        // closed by signalClosing() to signal shutdown; never written to otherwise
 	closing    uint32               // atomic: 0 = open, 1 = closing
 	signalOnce sync.Once
-	closeOnce  sync.Once
 }
 
 func isLargePayloadPassthrough(ctx *schemas.BifrostContext) bool {
@@ -122,14 +151,6 @@ func (pq *ProviderQueue) signalClosing() {
 	})
 }
 
-// closeQueue closes the provider queue.
-// Protected by sync.Once to prevent double-close.
-func (pq *ProviderQueue) closeQueue() {
-	pq.closeOnce.Do(func() {
-		close(pq.queue)
-	})
-}
-
 // isClosing returns true if the provider queue is closing.
 // Uses atomic load for lock-free checking.
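+//
+// The producer-side pattern these two signals support (illustrative sketch;
+// errShuttingDown is a stand-in name, not the actual error tryRequest returns):
+//
+//	if pq.isClosing() {
+//	    return errShuttingDown // fast path: skip the send entirely
+//	}
+//	select {
+//	case pq.queue <- msg: // safe: pq.queue is never closed
+//	case <-pq.done: // closed by signalClosing; receiving never panics
+//	    return errShuttingDown
+//	}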
func (pq *ProviderQueue) isClosing() bool { @@ -3109,57 +3130,36 @@ func (bifrost *Bifrost) RemoveProvider(providerKey schemas.ModelProvider) error } pq := pqValue.(*ProviderQueue) - // Step 2: Signal closing to producers (prevents new sends) - // This must happen before closing the queue to avoid "send on closed channel" panics + // Step 2: Signal closing. Blocks new producers (isClosing() returns true) and + // causes idle workers to drain remaining buffered requests with errors then exit. pq.signalClosing() bifrost.logger.Debug("signaled closing for provider %s", providerKey) - // Step 3: Now safe to close the queue (no new producers can send) - pq.closeQueue() - bifrost.logger.Debug("closed request queue for provider %s", providerKey) - - // Step 4: Wait for all workers to finish processing in-flight requests + // Step 3: Wait for all workers to finish in-flight requests and exit. waitGroup, exists := bifrost.waitGroups.Load(providerKey) if exists { waitGroup.(*sync.WaitGroup).Wait() bifrost.logger.Debug("all workers for provider %s have stopped", providerKey) } - // Step 5: Remove the provider from the request queues + // Step 3b: Final drain sweep — see drainQueueWithErrors for full explanation. + bifrost.drainQueueWithErrors(pq) + + // Step 4: Remove the provider from the request queues. bifrost.requestQueues.Delete(providerKey) - // Step 6: Remove the provider from the wait groups + // Step 5: Remove the provider from the wait groups. bifrost.waitGroups.Delete(providerKey) - // Step 7: Remove the provider from the providers slice - replacementAttempts := 0 - maxReplacementAttempts := 100 // Prevent infinite loops in high-contention scenarios - for { - replacementAttempts++ - if replacementAttempts > maxReplacementAttempts { - return fmt.Errorf("failed to replace provider %s in providers slice after %d attempts", providerKey, maxReplacementAttempts) - } - oldPtr := bifrost.providers.Load() - var oldSlice []schemas.Provider - if oldPtr != nil { - oldSlice = *oldPtr - } - // Create new slice without the old provider of this key - // Use exact capacity to avoid allocations - if len(oldSlice) == 0 { - return fmt.Errorf("provider %s not found in providers slice", providerKey) - } - newSlice := make([]schemas.Provider, 0, len(oldSlice)-1) - for _, existingProvider := range oldSlice { - if existingProvider.GetProviderKey() != providerKey { - newSlice = append(newSlice, existingProvider) - } - } - if bifrost.providers.CompareAndSwap(oldPtr, &newSlice) { - bifrost.logger.Debug("successfully removed provider instance for %s in providers slice", providerKey) - break - } - // Retrying as swapping did not work (likely due to concurrent modification) + // Step 6: Remove the provider from the providers slice. + if err := bifrost.removeProviderFromSlice(providerKey); err != nil { + bifrost.logger.Error( + "provider %s was removed from queues but could not be removed from the providers slice — "+ + "bifrost.providers is now inconsistent. "+ + "To recover: retry RemoveProvider(%s), or restart Bifrost if that fails.", + providerKey, providerKey, + ) + return err } bifrost.logger.Info("successfully removed provider %s", providerKey) @@ -3181,6 +3181,15 @@ func (bifrost *Bifrost) RemoveProvider(providerKey schemas.ModelProvider) error // Note: This operation will temporarily pause request processing for the specified provider // while the transition occurs. In-flight requests will complete before workers are stopped. 
// Buffered requests in the old queue will be transferred to the new queue to prevent loss.
+//
+// Concurrency safety — no-worker window:
+// UpdateProvider holds a per-provider write lock (providerMutex.Lock) for its entire
+// duration. All producer paths (tryRequest, tryStreamRequest) acquire the corresponding
+// read lock inside getProviderQueue before they can look up or enqueue into any queue.
+// This means no producer can observe or enqueue into newPq until UpdateProvider returns
+// and releases the write lock — at which point new workers are already running and
+// consuming newPq. There is therefore no window where newPq is visible to producers
+// but has zero workers.
 func (bifrost *Bifrost) UpdateProvider(providerKey schemas.ModelProvider) error {
 	bifrost.logger.Info(fmt.Sprintf("Updating provider configuration for provider %s", providerKey))
 	// Get the updated configuration from the account
@@ -3213,23 +3222,23 @@ func (bifrost *Bifrost) UpdateProvider(providerKey schemas.ModelProvider) error
 		queue:      make(chan *ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize),
 		done:       make(chan struct{}),
 		signalOnce: sync.Once{},
-		closeOnce:  sync.Once{},
 	}
 
-	// Step 2: Atomically replace the queue FIRST (new producers immediately get the new queue)
-	// This minimizes the window where requests fail during the update
+	// Step 2: Atomically replace the queue so new producers immediately use newPq.
 	bifrost.requestQueues.Store(providerKey, newPq)
 	bifrost.logger.Debug("stored new queue for provider %s, new producers will use it", providerKey)
 
-	// Step 3: Signal old queue is closing to producers that already have a reference
-	// Only in-flight producers with the old reference will see this
-	oldPq.signalClosing()
-	bifrost.logger.Debug("signaled closing for old queue of provider %s", providerKey)
-
-	// Step 4: Transfer any buffered requests from old queue to new queue
-	// This prevents request loss during the transition
+	// Step 3: Transfer buffered requests from the old queue to the new queue BEFORE
+	// signaling workers to stop. This ensures buffered requests are processed by the
+	// new workers rather than being drained with errors.
+	// Old workers are still running and may consume some items concurrently — that is
+	// fine, they process them normally.
+	// If newPq is full during transfer, all remaining buffered requests are cancelled
+	// immediately rather than blocking. This avoids a deadlock: a transfer blocked on
+	// newPq would wait for space that only opens up once new workers start consuming,
+	// and those workers start only after the transfer completes.
 	transferredCount := 0
-	var transferWaitGroup sync.WaitGroup
+	cancelledCount := 0
 	for {
 		select {
 		case msg := <-oldPq.queue:
@@ -3237,37 +3246,33 @@ func (bifrost *Bifrost) UpdateProvider(providerKey schemas.ModelProvider) error
 			case newPq.queue <- msg:
 				transferredCount++
 			default:
-				// New queue is full, handle this request in a goroutine
-				// This is unlikely with proper buffer sizing but provides safety
-				transferWaitGroup.Add(1)
-				go func(m *ChannelMessage) {
-					defer transferWaitGroup.Done()
+				// newPq is full — cancel this message and all remaining in oldPq.
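+				// The error send inside cancelMsg cannot block: each message's Err
+				// channel is buffered (size 1) and empty at this point, so the only
+				// escape hatch needed is the caller's own context cancellation.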
+ cancelMsg := func(r *ChannelMessage) { + prov, mod, _ := r.BifrostRequest.GetRequestFields() select { - case newPq.queue <- m: - // Message successfully transferred - case <-time.After(5 * time.Second): - bifrost.logger.Warn("Failed to transfer buffered request to new queue within timeout") - // Send error response to avoid hanging the client - provider, model, _ := m.BifrostRequest.GetRequestFields() - select { - case m.Err <- schemas.BifrostError{ - IsBifrostError: false, - Error: &schemas.ErrorField{ - Message: "request failed during provider concurrency update", - }, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: m.RequestType, - Provider: provider, - ModelRequested: model, - }, - }: - case <-time.After(1 * time.Second): - // If we can't send the error either, just log and continue - bifrost.logger.Warn("Failed to send error response during transfer timeout") - } + case r.Err <- schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{Message: "request failed during provider concurrency update: queue full"}, + ExtraFields: schemas.BifrostErrorExtraFields{ + RequestType: r.RequestType, + Provider: prov, + ModelRequested: mod, + }, + }: + case <-r.Context.Done(): } - }(msg) - goto transferComplete + } + cancelMsg(msg) + cancelledCount++ + for { + select { + case r := <-oldPq.queue: + cancelMsg(r) + cancelledCount++ + default: + goto transferComplete + } + } } default: // No more buffered messages @@ -3276,33 +3281,59 @@ func (bifrost *Bifrost) UpdateProvider(providerKey schemas.ModelProvider) error } transferComplete: - // Wait for all transfer goroutines to complete - transferWaitGroup.Wait() if transferredCount > 0 { bifrost.logger.Info("transferred %d buffered requests to new queue for provider %s", transferredCount, providerKey) } + if cancelledCount > 0 { + bifrost.logger.Warn("cancelled %d buffered requests during transfer for provider %s: new queue was full", cancelledCount, providerKey) + } - // Step 5: Close the old queue to signal workers to stop - oldPq.closeQueue() - bifrost.logger.Debug("closed old request queue for provider %s", providerKey) + // Step 4: Signal the old queue is closing. Producers that still hold a reference to + // oldPq will detect this via isClosing() and transparently re-route to newPq. + // This happens after the transfer so the new queue is already populated before + // stale producers attempt their re-route. + oldPq.signalClosing() + bifrost.logger.Debug("signaled closing for old queue of provider %s", providerKey) - // Step 6: Wait for all existing workers to finish processing in-flight requests + // Step 5: Wait for all existing workers to finish processing in-flight requests. + // Workers exit via oldPq.done (signalled above). waitGroup, exists := bifrost.waitGroups.Load(providerKey) if exists { waitGroup.(*sync.WaitGroup).Wait() bifrost.logger.Debug("all workers for provider %s have stopped", providerKey) } - // Step 7: Create new wait group for the updated workers + // Step 5b: Final drain sweep — see drainQueueWithErrors for full explanation. + bifrost.drainQueueWithErrors(oldPq) + + // Step 6: Create new wait group for the updated workers. bifrost.waitGroups.Store(providerKey, &sync.WaitGroup{}) - // Step 8: Create provider instance + // Step 7: Create provider instance. 
provider, err := bifrost.createBaseProvider(providerKey, providerConfig) if err != nil { - return fmt.Errorf("failed to create provider instance for %s: %v", providerKey, err) - } - - // Step 8.5: Atomically replace the provider in the providers slice + // Roll back: signal closing, remove from map, then drain. + // Order matters: Delete before drainQueueWithErrors so that producers + // re-routing via requestQueues.Load find nothing and return "provider + // shutting down" immediately, narrowing the TOCTOU window before the sweep. + newPq.signalClosing() + bifrost.requestQueues.Delete(providerKey) + bifrost.waitGroups.Delete(providerKey) + bifrost.drainQueueWithErrors(newPq) + if sliceErr := bifrost.removeProviderFromSlice(providerKey); sliceErr != nil { + bifrost.logger.Error( + "UpdateProvider rollback for %s is incomplete — provider was removed from queues "+ + "but could not be removed from the providers slice: %v. "+ + "bifrost.providers is now inconsistent. "+ + "To recover: call RemoveProvider(%s) then AddProvider to re-register it, "+ + "or restart Bifrost if that fails.", + providerKey, sliceErr, providerKey, + ) + } + return fmt.Errorf("provider update for %s failed during initialization; provider has been removed — re-add or retry UpdateProvider to restore it: %v", providerKey, err) + } + + // Step 8: Atomically replace the provider in the providers slice. // This must happen before starting new workers to prevent stale reads bifrost.logger.Debug("atomically replacing provider instance in providers slice for %s", providerKey) @@ -3312,7 +3343,21 @@ transferComplete: for { replacementAttempts++ if replacementAttempts > maxReplacementAttempts { - return fmt.Errorf("failed to replace provider %s in providers slice after %d attempts", providerKey, maxReplacementAttempts) + newPq.signalClosing() + bifrost.requestQueues.Delete(providerKey) + bifrost.waitGroups.Delete(providerKey) + bifrost.drainQueueWithErrors(newPq) + if sliceErr := bifrost.removeProviderFromSlice(providerKey); sliceErr != nil { + bifrost.logger.Error( + "UpdateProvider rollback for %s is incomplete — provider was removed from queues "+ + "but could not be removed from the providers slice: %v. "+ + "bifrost.providers is now inconsistent. "+ + "To recover: call RemoveProvider(%s) then AddProvider to re-register it, "+ + "or restart Bifrost if that fails.", + providerKey, sliceErr, providerKey, + ) + } + return fmt.Errorf("failed to replace provider %s in providers slice after %d attempts; provider has been removed — re-add or retry UpdateProvider to restore it", providerKey, maxReplacementAttempts) } oldPtr := bifrost.providers.Load() @@ -3348,7 +3393,7 @@ transferComplete: // Retrying as swapping did not work (likely due to concurrent modification) } - // Step 9: Start new workers with updated concurrency + // Step 9: Start new workers with updated concurrency. bifrost.logger.Debug("starting %d new workers for provider %s with buffer size %d", providerConfig.ConcurrencyAndBufferSize.Concurrency, providerKey, @@ -3384,6 +3429,33 @@ func (bifrost *Bifrost) getProviderMutex(providerKey schemas.ModelProvider) *syn return mutexValue.(*sync.RWMutex) } +// removeProviderFromSlice atomically removes the provider with the given key +// from bifrost.providers using a CAS retry loop. Callers hold the per-provider +// write mutex so no concurrent goroutine can re-add this key — contention is +// only from other providers' CAS operations, so the loop converges in at most +// a few iterations under any concurrency level. 
+// Returns an error if the limit is hit (state will be inconsistent). +func (bifrost *Bifrost) removeProviderFromSlice(providerKey schemas.ModelProvider) error { + const maxAttempts = 100 + for range maxAttempts { + oldPtr := bifrost.providers.Load() + if oldPtr == nil { + return nil + } + oldSlice := *oldPtr + newSlice := make([]schemas.Provider, 0, len(oldSlice)) + for _, p := range oldSlice { + if p.GetProviderKey() != providerKey { + newSlice = append(newSlice, p) + } + } + if bifrost.providers.CompareAndSwap(oldPtr, &newSlice) { + return nil + } + } + return fmt.Errorf("failed to remove provider %s from providers slice after %d attempts", providerKey, maxAttempts) +} + // MCP PUBLIC API // RegisterMCPTool registers a typed tool handler with the MCP integration. @@ -3694,7 +3766,6 @@ func (bifrost *Bifrost) prepareProvider(providerKey schemas.ModelProvider, confi queue: make(chan *ChannelMessage, config.ConcurrencyAndBufferSize.BufferSize), done: make(chan struct{}), signalOnce: sync.Once{}, - closeOnce: sync.Once{}, } bifrost.requestQueues.Store(providerKey, pq) @@ -4382,17 +4453,31 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif msg := bifrost.getChannelMessage(*preReq) msg.Context = ctx - // Check if provider is closing before attempting to send (lock-free atomic check) - // This prevents "send on closed channel" panics during provider removal/update + // If the queue is closing, check whether the provider was updated (new queue + // available) or removed. On update, transparently re-route to the new queue + // so in-flight producers don't get spurious errors. On removal, error out. + // + // Use a direct sync.Map lookup instead of getProviderQueue to avoid the + // lazy-creation path: getProviderQueue can resurrect a provider that was + // just removed by RemoveProvider if the account config still exists. if pq.isClosing() { - bifrost.releaseChannelMessage(msg) - bifrostErr := newBifrostErrorFromMsg("provider is shutting down") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, + var reroutedPq *ProviderQueue + if val, ok := bifrost.requestQueues.Load(provider); ok { + if candidate := val.(*ProviderQueue); candidate != pq && !candidate.isClosing() { + reroutedPq = candidate + } } - return nil, bifrostErr + if reroutedPq == nil { + bifrost.releaseChannelMessage(msg) + bifrostErr := newBifrostErrorFromMsg("provider is shutting down") + bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ + RequestType: req.RequestType, + Provider: provider, + ModelRequested: model, + } + return nil, bifrostErr + } + pq = reroutedPq } // Use select with done channel to detect shutdown during send @@ -4492,7 +4577,13 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif } return resp, nil case <-ctx.Done(): - bifrost.releaseChannelMessage(msg) + // Do NOT releaseChannelMessage here. The message is already enqueued and + // the worker still holds a reference to msg.Response and msg.Err. Returning + // those channels to the pool now would let the next request reuse them while + // the worker is still writing to them — stale data corruption. The worker + // never calls releaseChannelMessage itself, so this message leaks from the + // pool and is GC'd. That is intentional: a small pool leak on cancellation + // is far safer than corrupting another request's channels. 
provider, model, _ := req.GetRequestFields() return nil, newBifrostCtxDoneError(ctx, provider, model, req.RequestType, "waiting for provider response") } @@ -4629,17 +4720,31 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem msg := bifrost.getChannelMessage(*preReq) msg.Context = ctx - // Check if provider is closing before attempting to send (lock-free atomic check) - // This prevents "send on closed channel" panics during provider removal/update + // If the queue is closing, check whether the provider was updated (new queue + // available) or removed. On update, transparently re-route to the new queue + // so in-flight producers don't get spurious errors. On removal, error out. + // + // Use a direct sync.Map lookup instead of getProviderQueue to avoid the + // lazy-creation path: getProviderQueue can resurrect a provider that was + // just removed by RemoveProvider if the account config still exists. if pq.isClosing() { - bifrost.releaseChannelMessage(msg) - bifrostErr := newBifrostErrorFromMsg("provider is shutting down") - bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ - RequestType: req.RequestType, - Provider: provider, - ModelRequested: model, + var reroutedPq *ProviderQueue + if val, ok := bifrost.requestQueues.Load(provider); ok { + if candidate := val.(*ProviderQueue); candidate != pq && !candidate.isClosing() { + reroutedPq = candidate + } } - return nil, bifrostErr + if reroutedPq == nil { + bifrost.releaseChannelMessage(msg) + bifrostErr := newBifrostErrorFromMsg("provider is shutting down") + bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ + RequestType: req.RequestType, + Provider: provider, + ModelRequested: model, + } + return nil, bifrostErr + } + pq = reroutedPq } // Use select with done channel to detect shutdown during send @@ -4721,6 +4826,11 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem return newBifrostMessageChan(recoveredResp), nil } return nil, &bifrostErrVal + case <-ctx.Done(): + // Do NOT releaseChannelMessage here — see the identical note in tryRequest. + // Worker still holds msg.ResponseStream/msg.Err; releasing now corrupts the + // next request that reuses those pooled channels. + return nil, newBifrostCtxDoneError(ctx, provider, model, req.RequestType, "while waiting for stream response") } } @@ -4937,7 +5047,38 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas } }() - for req := range pq.queue { + for { + var req *ChannelMessage + select { + case r := <-pq.queue: + req = r + case <-pq.done: + // Provider is shutting down. Drain any buffered requests and send + // back errors so callers are not left blocked on their response channel. + for { + select { + case r := <-pq.queue: + provKey, mod, _ := r.GetRequestFields() + select { + case r.Err <- schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: "provider is shutting down", + }, + ExtraFields: schemas.BifrostErrorExtraFields{ + RequestType: r.RequestType, + Provider: provKey, + ModelRequested: mod, + }, + }: + case <-r.Context.Done(): + } + default: + return + } + } + } + _, model, _ := req.BifrostRequest.GetRequestFields() var result *schemas.BifrostResponse @@ -5984,6 +6125,47 @@ func (bifrost *Bifrost) getChannelMessage(req schemas.BifrostRequest) *ChannelMe return msg } +// drainQueueWithErrors drains all buffered messages from pq and sends each a +// "provider is shutting down" error. 
It must be called after all workers for
+// the queue have exited (i.e. after wg.Wait()) to cover the TOCTOU window:
+// a producer that passed isClosing() just before signalClosing fired can still
+// win the `case pq.queue <- msg` branch in tryRequest, landing a message in
+// the queue after the last worker's drain loop already exited via `default:`.
+// Without this sweep, those callers block forever on <-msg.Response / <-msg.Err.
+//
+// Residual TOCTOU window (known limitation): this sweep runs exactly once via
+// a non-blocking `select { default: }`. A producer that deposits a message
+// after the sweep's `default:` branch exits has no worker and no sweep to drain
+// it — the caller will block until its own context is cancelled. Fully closing
+// this window requires a sender-side reference count (so the last producer can
+// signal "queue is fully idle"), which is intentionally not implemented because
+// it would add per-send atomic overhead on the hot path.
+func (bifrost *Bifrost) drainQueueWithErrors(pq *ProviderQueue) {
+	for {
+		select {
+		case r := <-pq.queue:
+			provKey, mod, _ := r.GetRequestFields()
+			select {
+			case r.Err <- schemas.BifrostError{
+				IsBifrostError: false,
+				Error:          &schemas.ErrorField{Message: "provider is shutting down"},
+				ExtraFields: schemas.BifrostErrorExtraFields{
+					RequestType:    r.RequestType,
+					Provider:       provKey,
+					ModelRequested: mod,
+				},
+			}:
+			case <-r.Context.Done():
+				// No time.After needed: r.Err is an empty size-1 buffered channel
+				// dedicated to this request, so the send always completes immediately
+				// unless the caller already cancelled. ctx.Done() is the only valid
+				// escape.
+			}
+		default:
+			return
+		}
+	}
+}
+
 // releaseChannelMessage returns a ChannelMessage and its channels to their respective pools.
 func (bifrost *Bifrost) releaseChannelMessage(msg *ChannelMessage) {
 	// Put channels back in pools
@@ -6491,15 +6673,12 @@ func (bifrost *Bifrost) Shutdown() {
 	if bifrost.ctx.Err() == nil && bifrost.cancel != nil {
 		bifrost.cancel()
 	}
-	// ALWAYS close all provider queues to signal workers to stop,
-	// even if context was already cancelled. This prevents goroutine leaks.
-	// Use the ProviderQueue lifecycle: signal closing, then close the queue
+	// Signal all provider queues to close. Workers exit via pq.done;
+	// we never close pq.queue to avoid "send on closed channel" panics in
+	// producers that are concurrently in tryRequest.
 	bifrost.requestQueues.Range(func(key, value interface{}) bool {
 		pq := value.(*ProviderQueue)
-		// Signal closing to producers (uses sync.Once internally)
 		pq.signalClosing()
-		// Close the queue to signal workers (uses sync.Once internally)
-		pq.closeQueue()
 		return true
 	})
 
@@ -6510,6 +6689,12 @@
 		return true
 	})
 
+	// Final drain sweep — same reasoning as RemoveProvider's Step 3b.
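+	// All workers have already exited by this point (waited on above), which is
+	// drainQueueWithErrors' documented precondition.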
+ bifrost.requestQueues.Range(func(key, value interface{}) bool { + bifrost.drainQueueWithErrors(value.(*ProviderQueue)) + return true + }) + // Cleanup MCP manager if bifrost.MCPManager != nil { err := bifrost.MCPManager.Cleanup() diff --git a/core/bifrost_test.go b/core/bifrost_test.go index cb22f5e359..6944ed1d9d 100644 --- a/core/bifrost_test.go +++ b/core/bifrost_test.go @@ -3,8 +3,10 @@ package bifrost import ( "context" "fmt" + "runtime" "strings" "sync" + "sync/atomic" "testing" "time" @@ -1300,3 +1302,998 @@ func TestUpdateProvider_ProviderSliceIntegrity(t *testing.T) { } }) } + +// TestProviderQueue_SendOnClosedChannel_Race demonstrates the TOCTOU race that +// caused the "send on closed channel" production panic in the OLD code. +// +// The old code called close(pq.queue) during provider shutdown. The sequence: +// 1. Producer calls isClosing() → false (queue is still open) +// 2. Concurrently: shutdown calls signalClosing() then close(pq.queue) +// 3. Producer enters select { case pq.queue <- msg: ... case <-pq.done: ... } +// → PANIC: Go's selectgo iterates cases in a randomised pollorder. When the +// closed-channel send case is checked first, it immediately panics via +// goto sclose — before it can reach the done case. +// The case <-pq.done: guard only saves you when done happens to be checked +// first in that random ordering (≈50 % of the time with two cases). +// +// THE FIX: pq.queue is never closed. See the ProviderQueue struct comment for +// the full explanation. This test is kept as a proof-of-concept showing why +// closing pq.queue is unsafe; the fix is validated by TestProviderQueue_NoPanicWithoutCloseQueue. +// +// We run many iterations so that the panic is statistically certain to surface +// at least once, confirming the hypothesis. +func TestProviderQueue_SendOnClosedChannel_Race(t *testing.T) { + // With two select cases each iteration has a ~50 % chance of panicking. + // The probability of never panicking in 200 iterations is (0.5)^200 ≈ 0. + const iterations = 200 + panicCount := 0 + + for i := 0; i < iterations; i++ { + func() { + pq := &ProviderQueue{ + queue: make(chan *ChannelMessage, 10), + done: make(chan struct{}), + signalOnce: sync.Once{}, + } + + // Synchronization barriers to force the exact race interleaving. + passedIsClosingCheck := make(chan struct{}) + queueClosed := make(chan struct{}) + + var panicked bool + var wg sync.WaitGroup + wg.Add(1) + + // Producer — mirrors the hot path in tryRequest. + go func() { + defer wg.Done() + defer func() { + if r := recover(); r != nil && fmt.Sprint(r) == "send on closed channel" { + panicked = true + } + }() + + // Step 1: isClosing() passes — queue is open. + if pq.isClosing() { + return + } + + // Signal: past the isClosing() gate. + close(passedIsClosingCheck) + + // Wait for the queue to be closed. This represents the real work + // tryRequest does between the isClosing() check and the select + // (MCP setup, tracer lookup, plugin pipeline acquisition). + <-queueClosed + + // Step 2: enter the exact select guard used in production. + // pq.queue is closed AND pq.done is closed. + // When selectgo picks the send case first in its random pollorder + // it hits goto sclose and panics — the done case cannot save it. + msg := &ChannelMessage{} + select { + case pq.queue <- msg: // panics ~50 % of iterations + case <-pq.done: // selected the other ~50 % + } + }() + + // Closer — mirrors UpdateProvider / RemoveProvider. 
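+			// (old behaviour, pre-fix: signal closing, then close the queue
+			// channel; the close(pq.queue) below is the bug under test, which the
+			// fixed RemoveProvider/UpdateProvider no longer perform)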
+ go func() { + <-passedIsClosingCheck + pq.signalClosing() // closes done, sets closing = 1 + close(pq.queue) + close(queueClosed) // release the producer into the select + }() + + wg.Wait() + if panicked { + panicCount++ + } + }() + } + + if panicCount == 0 { + t.Fatalf("expected at least one 'send on closed channel' panic across %d iterations, got none", iterations) + } + t.Logf("confirmed: panic triggered in %d / %d iterations — hypothesis is correct", panicCount, iterations) +} + +// ============================================================================= +// ProviderQueue Unit Tests +// +// These tests exercise the ProviderQueue lifecycle in isolation — no full +// Bifrost instance required. They validate the core safety invariants that +// prevent the "send on closed channel" panic. +// ============================================================================= + +// newTestChannelMessage creates a minimal ChannelMessage suitable for drain tests. +// The Err channel is buffered (size 1) so the worker can send without blocking. +func newTestChannelMessage(ctx *schemas.BifrostContext) *ChannelMessage { + return &ChannelMessage{ + BifrostRequest: schemas.BifrostRequest{ + RequestType: schemas.ChatCompletionRequest, + ChatRequest: &schemas.BifrostChatRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4", + }, + }, + Context: ctx, + Response: make(chan *schemas.BifrostResponse, 1), + Err: make(chan schemas.BifrostError, 1), + } +} + +// TestProviderQueue_IsClosingStateTransition verifies the atomic state flag: +// isClosing() must return false before signalClosing() and true after. +func TestProviderQueue_IsClosingStateTransition(t *testing.T) { + pq := &ProviderQueue{ + queue: make(chan *ChannelMessage, 10), + done: make(chan struct{}), + signalOnce: sync.Once{}, + } + + if pq.isClosing() { + t.Fatal("isClosing() must be false before signalClosing() is called") + } + + pq.signalClosing() + + if !pq.isClosing() { + t.Fatal("isClosing() must be true after signalClosing() is called") + } + + // done channel must also be closed + select { + case <-pq.done: + // correct: done is closed + default: + t.Fatal("pq.done must be closed after signalClosing()") + } + + // queue channel must remain OPEN — this is the core of the fix + // (sending should not panic even though done is closed) + panicked := false + func() { + defer func() { + if r := recover(); r != nil { + panicked = true + } + }() + select { + case pq.queue <- &ChannelMessage{}: + case <-pq.done: // done is closed so this is always ready — no panic + } + }() + if panicked { + t.Fatal("queue channel must stay open after signalClosing() — sending to it must not panic") + } +} + +// TestProviderQueue_SignalOnceIdempotent verifies that calling signalClosing() +// multiple times is safe. sync.Once ensures done is only closed once and the +// atomic store only happens once — no "close of closed channel" panic. 
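+//
+// For reference, signalClosing (core/bifrost.go) is assumed to reduce to the
+// following shape; its exact body is elided from this diff:
+//
+//	func (pq *ProviderQueue) signalClosing() {
+//	    pq.signalOnce.Do(func() {
+//	        atomic.StoreUint32(&pq.closing, 1)
+//	        close(pq.done)
+//	    })
+//	}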
+func TestProviderQueue_SignalOnceIdempotent(t *testing.T) { + pq := &ProviderQueue{ + queue: make(chan *ChannelMessage, 10), + done: make(chan struct{}), + signalOnce: sync.Once{}, + } + + defer func() { + if r := recover(); r != nil { + t.Fatalf("unexpected panic from multiple signalClosing() calls: %v", r) + } + }() + + pq.signalClosing() + pq.signalClosing() + pq.signalClosing() + + if !pq.isClosing() { + t.Fatal("isClosing() must be true after multiple signalClosing() calls") + } +} + +// TestProviderQueue_WorkerExitsViaDone verifies that a worker running the +// fixed select loop exits cleanly after signalClosing() without closeQueue(). +// Before the fix, workers used `for req := range pq.queue` which required +// the channel to be closed. After the fix, done is the exit signal. +func TestProviderQueue_WorkerExitsViaDone(t *testing.T) { + pq := &ProviderQueue{ + queue: make(chan *ChannelMessage, 10), + done: make(chan struct{}), + signalOnce: sync.Once{}, + } + + workerExited := make(chan struct{}) + + // Minimal worker loop — mirrors the exact select pattern in requestWorker + go func() { + defer close(workerExited) + for { + select { + case r, ok := <-pq.queue: + if !ok { + return + } + _ = r // process (no-op in this test) + case <-pq.done: + // Drain remaining buffered items (queue is empty here) + for { + select { + case <-pq.queue: + default: + return + } + } + } + } + }() + + // Worker is now blocked on the select. Signal shutdown WITHOUT closing queue. + pq.signalClosing() + + select { + case <-workerExited: + // correct: worker exited via done + case <-time.After(2 * time.Second): + t.Fatal("worker did not exit after signalClosing() — it may be stuck on range over unclosed channel") + } +} + +// TestProviderQueue_WorkerDrainSendsErrors verifies the drain behaviour when +// done fires while items are still buffered: every buffered ChannelMessage must +// receive a "provider is shutting down" error on its Err channel. No client +// should be left blocked waiting for a response that will never come. +// +// This test exercises the drain path directly — same code as requestWorker's +// case <-pq.done: branch — to avoid a non-deterministic select race between the +// normal processing path and the done path. +func TestProviderQueue_WorkerDrainSendsErrors(t *testing.T) { + const numBuffered = 5 + + pq := &ProviderQueue{ + queue: make(chan *ChannelMessage, numBuffered+2), + done: make(chan struct{}), + signalOnce: sync.Once{}, + } + + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + + // Pre-fill queue — simulates requests buffered when done fires + msgs := make([]*ChannelMessage, numBuffered) + for i := 0; i < numBuffered; i++ { + msgs[i] = newTestChannelMessage(ctx) + pq.queue <- msgs[i] + } + + // Signal closing: done is now closed + pq.signalClosing() + + // Execute the drain path synchronously — exactly what requestWorker does in + // the case <-pq.done: branch. This is deterministic: we know done is closed + // and the queue has numBuffered items. 
+ <-pq.done // fires immediately since signalClosing was already called +drainLoop: + for { + select { + case r := <-pq.queue: + provKey, mod, _ := r.GetRequestFields() + r.Err <- schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: "provider is shutting down", + }, + ExtraFields: schemas.BifrostErrorExtraFields{ + RequestType: r.RequestType, + Provider: provKey, + ModelRequested: mod, + }, + } + default: + break drainLoop + } + } + + // Verify every message received a shutdown error + for i, msg := range msgs { + select { + case bifrostErr := <-msg.Err: + if bifrostErr.Error == nil { + t.Errorf("message %d: received nil Error field", i) + continue + } + if bifrostErr.Error.Message != "provider is shutting down" { + t.Errorf("message %d: expected 'provider is shutting down', got %q", + i, bifrostErr.Error.Message) + } + if bifrostErr.ExtraFields.Provider != schemas.OpenAI { + t.Errorf("message %d: expected provider %s, got %s", + i, schemas.OpenAI, bifrostErr.ExtraFields.Provider) + } + if bifrostErr.ExtraFields.RequestType != schemas.ChatCompletionRequest { + t.Errorf("message %d: expected requestType %v, got %v", + i, schemas.ChatCompletionRequest, bifrostErr.ExtraFields.RequestType) + } + default: + t.Errorf("message %d: no error received — client would be left hanging indefinitely", i) + } + } +} + +// TestProviderQueue_NoPanicWithoutCloseQueue verifies that the fixed hot path +// — select { case pq.queue <- msg | case <-pq.done } — never panics when +// signalClosing() fires but the queue channel is NOT closed. +// +// This is the direct inverse of TestProviderQueue_SendOnClosedChannel_Race: +// that test proves the old code panics ~50% of the time; this test proves +// the fixed code panics 0% of the time. +func TestProviderQueue_NoPanicWithoutCloseQueue(t *testing.T) { + const iterations = 500 + + for i := 0; i < iterations; i++ { + func() { + pq := &ProviderQueue{ + queue: make(chan *ChannelMessage, 10), + done: make(chan struct{}), + signalOnce: sync.Once{}, + } + + passedIsClosingCheck := make(chan struct{}) + shutdownDone := make(chan struct{}) + + var panicked bool + var wg sync.WaitGroup + wg.Add(1) + + // Producer: mirrors the tryRequest hot path after the fix. + // Passes isClosing(), waits for signalClosing, then sends. + // The queue channel is NEVER closed — only done is closed. + go func() { + defer wg.Done() + defer func() { + if r := recover(); r != nil { + panicked = true + } + }() + + if pq.isClosing() { + return + } + close(passedIsClosingCheck) + <-shutdownDone + + msg := &ChannelMessage{} + select { + case pq.queue <- msg: // queue is open → safe to send + case <-pq.done: // done is closed → selected immediately + } + }() + + // Closer: signal shutdown but never close the queue channel + go func() { + <-passedIsClosingCheck + pq.signalClosing() // closes done; does NOT close queue + close(shutdownDone) + }() + + wg.Wait() + + if panicked { + t.Errorf("iteration %d: unexpected panic — queue must not be closed in the fixed path", i) + } + }() + + if t.Failed() { + return + } + } + + t.Logf("confirmed: zero panics in %d iterations with the fix applied", iterations) +} + +// ============================================================================= +// UpdateProvider Lifecycle Tests +// +// These tests verify the three key invariants of the UpdateProvider fix: +// 1. New queue is stored BEFORE signalClosing fires (stale producers re-route) +// 2. Transfer happens BEFORE signalClosing (items go to new workers, not errored) +// 3. 
Concurrent producers + UpdateProvider produce zero panics
+// =============================================================================
+
+// TestUpdateProvider_StaleProducerReroutes verifies that a "stale producer" —
+// a goroutine that fetched oldPq before UpdateProvider atomically replaced it —
+// can transparently re-route to newPq when it later detects isClosing().
+//
+// The re-routing logic in tryRequest is a direct requestQueues lookup (not
+// getProviderQueue, whose lazy-creation path could resurrect a provider that
+// was just removed):
+//
+//	if pq.isClosing() {
+//	    if val, ok := bifrost.requestQueues.Load(provider); ok {
+//	        if candidate := val.(*ProviderQueue); candidate != pq && !candidate.isClosing() {
+//	            pq = candidate // transparent re-route
+//	        }
+//	    }
+//	}
+//
+// This test exercises that exact sequence without a full Bifrost instance.
+func TestUpdateProvider_StaleProducerReroutes(t *testing.T) {
+	var requestQueues sync.Map
+	provider := schemas.OpenAI
+
+	oldPq := &ProviderQueue{
+		queue:      make(chan *ChannelMessage, 10),
+		done:       make(chan struct{}),
+		signalOnce: sync.Once{},
+	}
+	newPq := &ProviderQueue{
+		queue:      make(chan *ChannelMessage, 10),
+		done:       make(chan struct{}),
+		signalOnce: sync.Once{},
+	}
+
+	// Initial state: requestQueues holds oldPq
+	requestQueues.Store(provider, oldPq)
+
+	// Stale producer: fetched its reference before UpdateProvider ran
+	stalePq := oldPq
+
+	// Simulate UpdateProvider steps 2 + 4:
+	// Step 2: atomically replace — new producers now get newPq
+	requestQueues.Store(provider, newPq)
+	// Step 4: signal old closing — stale producers will detect this
+	oldPq.signalClosing()
+
+	// --- Stale producer detects isClosing and attempts re-route ---
+	var reroutedPq *ProviderQueue
+	if stalePq.isClosing() {
+		if val, ok := requestQueues.Load(provider); ok {
+			candidate := val.(*ProviderQueue)
+			if candidate != stalePq {
+				reroutedPq = candidate
+			}
+		}
+	}
+
+	if reroutedPq == nil {
+		t.Fatal("stale producer failed to re-route: re-route returned nil (check step ordering)")
+	}
+	if reroutedPq != newPq {
+		t.Fatal("stale producer re-routed to wrong queue: expected newPq")
+	}
+	if reroutedPq.isClosing() {
+		t.Fatal("re-routed queue is already closing — re-route is useless (newPq must be fresh)")
+	}
+
+	// Verify: sending to re-routed queue succeeds without panic
+	panicked := false
+	func() {
+		defer func() {
+			if r := recover(); r != nil {
+				panicked = true
+			}
+		}()
+		msg := &ChannelMessage{}
+		select {
+		case reroutedPq.queue <- msg:
+		case <-reroutedPq.done:
+			t.Error("newPq.done fired — newPq should be open")
+		}
+	}()
+	if panicked {
+		t.Fatal("panic while sending to re-routed queue — queue must not be closed")
+	}
+}
+
+// TestUpdateProvider_TransferOrdering verifies the ordering invariant:
+// items are moved from oldPq to newPq BEFORE signalClosing(oldPq) is called.
+//
+// Observable consequence: during the entire transfer loop, oldPq.isClosing()
+// must remain false. Only after transfer completes does signalClosing fire.
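+//
+// The production ordering under test (UpdateProvider, abbreviated;
+// transferBuffered stands in for the inline transfer loop, not a real helper):
+//
+//	requestQueues.Store(providerKey, newPq) // step 2: expose newPq to producers
+//	transferBuffered(oldPq, newPq)          // step 3: move buffered work across
+//	oldPq.signalClosing()                   // step 4: only now mark oldPq closing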
+func TestUpdateProvider_TransferOrdering(t *testing.T) { + const numMessages = 8 + + oldPq := &ProviderQueue{ + queue: make(chan *ChannelMessage, numMessages+2), + done: make(chan struct{}), + signalOnce: sync.Once{}, + } + newPq := &ProviderQueue{ + queue: make(chan *ChannelMessage, numMessages+2), + done: make(chan struct{}), + signalOnce: sync.Once{}, + } + + // Pre-fill oldPq — simulates buffered requests at the moment UpdateProvider runs + for i := 0; i < numMessages; i++ { + oldPq.queue <- &ChannelMessage{} + } + + // Invariant check before transfer begins + if oldPq.isClosing() { + t.Fatal("invariant violated: oldPq already closing before transfer begins") + } + + // Perform transfer, mirroring UpdateProvider step 3. + // Record whether isClosing() ever fired during the loop. + closingDuringTransfer := false + transferred := 0 + for { + select { + case msg := <-oldPq.queue: + if oldPq.isClosing() { + closingDuringTransfer = true + } + newPq.queue <- msg + transferred++ + default: + goto transferComplete + } + } +transferComplete: + + if closingDuringTransfer { + t.Error("invariant violated: oldPq was already closing during transfer — " + + "signalClosing must fire AFTER the transfer loop completes") + } + + // NOW signal closing, mirroring UpdateProvider step 4 + oldPq.signalClosing() + + if !oldPq.isClosing() { + t.Error("expected isClosing() == true after signalClosing()") + } + + // All messages must have moved to newPq + if transferred != numMessages { + t.Errorf("expected %d messages transferred, got %d", numMessages, transferred) + } + if len(newPq.queue) != numMessages { + t.Errorf("expected %d messages in newPq after transfer, got %d", numMessages, len(newPq.queue)) + } + if len(oldPq.queue) != 0 { + t.Errorf("expected 0 messages remaining in oldPq after transfer, got %d", len(oldPq.queue)) + } +} + +// TestUpdateProvider_NoPanicConcurrentAccess verifies that concurrent producers +// sending to a queue that is being replaced (UpdateProvider-style) never cause +// a "send on closed channel" panic. +// +// This test directly models the production scenario that triggered the bug: +// many goroutines continuously send to a ProviderQueue while UpdateProvider +// atomically swaps the queue and signals the old one closing. With the fix +// (queue channel is never closed), the select in producers is always safe. 
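+//
+// Worth running under the race detector as well, e.g. from the repository root:
+//
+//	go test -race -run TestUpdateProvider_NoPanicConcurrentAccess ./core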
+func TestUpdateProvider_NoPanicConcurrentAccess(t *testing.T) { + const ( + numProducers = 10 + numUpdates = 30 + producerRunTime = 300 * time.Millisecond + ) + + var requestQueues sync.Map + provider := schemas.OpenAI + + makePq := func() *ProviderQueue { + return &ProviderQueue{ + queue: make(chan *ChannelMessage, 200), + done: make(chan struct{}), + signalOnce: sync.Once{}, + } + } + + initialPq := makePq() + requestQueues.Store(provider, initialPq) + + var panicCount int64 + var transferDropCount int64 + + stop := make(chan struct{}) + var producerWg sync.WaitGroup + + // Drainer: continuously empties queues so producers never block on a full queue + drainStop := make(chan struct{}) + go func() { + for { + select { + case <-drainStop: + return + default: + if val, ok := requestQueues.Load(provider); ok { + pq := val.(*ProviderQueue) + select { + case <-pq.queue: + default: + } + } + runtime.Gosched() + } + } + }() + + // Producers: continuously simulate the tryRequest hot path + for i := 0; i < numProducers; i++ { + producerWg.Add(1) + go func() { + defer producerWg.Done() + for { + select { + case <-stop: + return + default: + } + + val, ok := requestQueues.Load(provider) + if !ok { + runtime.Gosched() + continue + } + pq := val.(*ProviderQueue) + + func() { + defer func() { + if r := recover(); r != nil { + atomic.AddInt64(&panicCount, 1) + } + }() + + // Re-route check (mirrors tryRequest) + if pq.isClosing() { + if newVal, ok2 := requestQueues.Load(provider); ok2 { + if candidate := newVal.(*ProviderQueue); candidate != pq { + pq = candidate + } + } + // If still closing (RemoveProvider path), just return + if pq.isClosing() { + return + } + } + + msg := &ChannelMessage{} + select { + case pq.queue <- msg: + case <-pq.done: + case <-stop: // unblock immediately when the test signals stop + } + }() + + runtime.Gosched() + } + }() + } + + // Updater: repeatedly performs UpdateProvider-style queue replacements + var updaterWg sync.WaitGroup + updaterWg.Add(1) + go func() { + defer updaterWg.Done() + for i := 0; i < numUpdates; i++ { + val, ok := requestQueues.Load(provider) + if !ok { + continue + } + oldPq := val.(*ProviderQueue) + newPq := makePq() + + // Mirror production UpdateProvider step order exactly: + // Step 2: expose newPq first so stale producers can re-route to it + // once they see oldPq is closing. + requestQueues.Store(provider, newPq) + + // Step 3: transfer buffered messages oldPq → newPq. + drain: + for { + select { + case msg := <-oldPq.queue: + select { + case newPq.queue <- msg: + default: + // newPq full during transfer — mirrors production cancel path. + atomic.AddInt64(&transferDropCount, 1) + } + default: + break drain + } + } + + // Step 4: signal closing — producers holding a stale oldPq ref now + // re-route to newPq (already in the map from step 2). 
+ oldPq.signalClosing() + + time.Sleep(5 * time.Millisecond) + } + }() + + time.Sleep(producerRunTime) + close(stop) + close(drainStop) + producerWg.Wait() + updaterWg.Wait() + + if n := atomic.LoadInt64(&panicCount); n > 0 { + t.Errorf("detected %d panic(s) — fix did not eliminate the concurrent-access race", n) + } else { + t.Logf("confirmed: zero panics across %d producers + %d queue replacements over %v", + numProducers, numUpdates, producerRunTime) + } + if drops := atomic.LoadInt64(&transferDropCount); drops > 0 { + t.Logf("note: %d message(s) dropped during transfer (oldPq had >200 buffered items) — does not affect panic correctness", drops) + } +} + +// ============================================================================= +// RemoveProvider Lifecycle Tests +// +// These tests verify the behavioral contract of RemoveProvider: +// 1. signalClosing() blocks new producers (isClosing() → true) +// 2. Buffered items in the queue get "provider is shutting down" errors +// 3. Workers exit cleanly and the WaitGroup reaches zero +// ============================================================================= + +// TestRemoveProvider_BlocksNewProducers verifies that after signalClosing(), +// isClosing() returns true. Producers check this flag before sending and return +// a "provider is shutting down" error rather than trying to enqueue. +func TestRemoveProvider_BlocksNewProducers(t *testing.T) { + pq := &ProviderQueue{ + queue: make(chan *ChannelMessage, 10), + done: make(chan struct{}), + signalOnce: sync.Once{}, + } + + // Sanity: before shutdown, producers can proceed + if pq.isClosing() { + t.Fatal("isClosing() must be false before RemoveProvider runs") + } + + // RemoveProvider step 2: signal closing + pq.signalClosing() + + // New producers must see isClosing() == true and abort + if !pq.isClosing() { + t.Fatal("isClosing() must be true after signalClosing() (RemoveProvider)") + } + + // done must be closed so any producer blocked in the select unblocks immediately + select { + case <-pq.done: + // correct + default: + t.Fatal("pq.done must be closed after signalClosing() so blocking producers unblock") + } + + // CRITICAL: queue channel must remain OPEN — closing it would cause panics in + // any producer that entered the select before seeing isClosing(). + // With the fix, we NEVER close the queue channel. + panicked := false + func() { + defer func() { + if r := recover(); r != nil { + panicked = true + } + }() + // A select with done closed always takes the done case — safe, no panic + select { + case pq.queue <- &ChannelMessage{}: + case <-pq.done: + } + }() + if panicked { + t.Fatal("queue channel must stay open after signalClosing() — closing it causes panics") + } +} + +// TestRemoveProvider_BufferedRequestsGetErrors verifies the drain contract: +// items queued BEFORE signalClosing fires must each receive a +// "provider is shutting down" error on their Err channel. No client should be +// left hanging. +// +// This test exercises the drain logic directly — the same code path that +// requestWorker executes in its case <-pq.done: branch — to avoid the +// non-deterministic select race where the normal processing path can pick up +// items before done fires. 
+func TestRemoveProvider_BufferedRequestsGetErrors(t *testing.T) { + const numBuffered = 8 + + pq := &ProviderQueue{ + queue: make(chan *ChannelMessage, numBuffered+5), + done: make(chan struct{}), + signalOnce: sync.Once{}, + } + + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + + // Buffer requests — simulates requests already queued when RemoveProvider runs + msgs := make([]*ChannelMessage, numBuffered) + for i := 0; i < numBuffered; i++ { + msgs[i] = newTestChannelMessage(ctx) + pq.queue <- msgs[i] + } + + // RemoveProvider step 2: signal closing + pq.signalClosing() + + // Execute the drain path — exactly what requestWorker does in case <-pq.done: + <-pq.done // fires immediately since signalClosing was already called +drainLoop: + for { + select { + case r := <-pq.queue: + provKey, mod, _ := r.GetRequestFields() + r.Err <- schemas.BifrostError{ + IsBifrostError: false, + Error: &schemas.ErrorField{ + Message: "provider is shutting down", + }, + ExtraFields: schemas.BifrostErrorExtraFields{ + RequestType: r.RequestType, + Provider: provKey, + ModelRequested: mod, + }, + } + default: + break drainLoop + } + } + + // Every buffered message must have received a shutdown error + for i, msg := range msgs { + select { + case bifrostErr := <-msg.Err: + if bifrostErr.Error == nil { + t.Errorf("message %d: got nil Error field in BifrostError", i) + continue + } + if bifrostErr.Error.Message != "provider is shutting down" { + t.Errorf("message %d: expected 'provider is shutting down', got %q", + i, bifrostErr.Error.Message) + } + if bifrostErr.ExtraFields.Provider != schemas.OpenAI { + t.Errorf("message %d: expected provider %s, got %s", + i, schemas.OpenAI, bifrostErr.ExtraFields.Provider) + } + if bifrostErr.ExtraFields.RequestType != schemas.ChatCompletionRequest { + t.Errorf("message %d: expected requestType %v, got %v", + i, schemas.ChatCompletionRequest, bifrostErr.ExtraFields.RequestType) + } + default: + t.Errorf("message %d: no error received — client would be left hanging indefinitely", i) + } + } +} + +// TestRemoveProvider_WorkerWaitGroupCompletes verifies that after signalClosing(), +// the worker goroutine decrements the WaitGroup and wg.Wait() returns promptly. +// This mirrors what RemoveProvider does: signal, then Wait() before cleanup. 
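+//
+// The production sequence under test (RemoveProvider, abbreviated):
+//
+//	pq.signalClosing()                 // step 2: block producers, wake workers
+//	waitGroup.(*sync.WaitGroup).Wait() // step 3: wait for in-flight requests
+//	bifrost.drainQueueWithErrors(pq)   // step 3b: sweep TOCTOU stragglers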
+func TestRemoveProvider_WorkerWaitGroupCompletes(t *testing.T) { + pq := &ProviderQueue{ + queue: make(chan *ChannelMessage, 10), + done: make(chan struct{}), + signalOnce: sync.Once{}, + } + + var wg sync.WaitGroup + wg.Add(1) + + // Worker goroutine — mirrors requestWorker's WaitGroup contract + go func() { + defer wg.Done() + for { + select { + case r, ok := <-pq.queue: + if !ok { + return + } + _ = r + case <-pq.done: + // Drain remaining (empty in this test) + for { + select { + case <-pq.queue: + default: + return + } + } + } + } + }() + + // Tiny sleep to ensure worker is parked on select before we signal + time.Sleep(10 * time.Millisecond) + + // RemoveProvider step 2: signal closing + pq.signalClosing() + + // RemoveProvider step 3: wait for workers — must complete promptly + waitReturned := make(chan struct{}) + go func() { + wg.Wait() + close(waitReturned) + }() + + select { + case <-waitReturned: + // correct: WaitGroup reached zero after signalClosing() + case <-time.After(2 * time.Second): + t.Fatal("wg.Wait() did not return after signalClosing() — worker is stuck (would deadlock RemoveProvider)") + } +} + +// TestRemoveProvider_ConcurrentNewProducersDuringShutdown verifies that +// concurrent producers trying to enqueue after RemoveProvider calls +// signalClosing() all get safe "provider is shutting down" errors — none panic. +// This tests the TOCTOU window: producer passes isClosing() check, then done fires. +func TestRemoveProvider_ConcurrentNewProducersDuringShutdown(t *testing.T) { + const numProducers = 50 + + pq := &ProviderQueue{ + queue: make(chan *ChannelMessage, numProducers+10), + done: make(chan struct{}), + signalOnce: sync.Once{}, + } + + var panicCount int64 + var shutdownErrors int64 + var successfulSends int64 + + // Gate: all producers start together after isClosing() passes + passedGate := make(chan struct{}) + var gateOnce sync.Once + shutdownFired := make(chan struct{}) + + var producerWg sync.WaitGroup + + for i := 0; i < numProducers; i++ { + producerWg.Add(1) + go func() { + defer producerWg.Done() + defer func() { + if r := recover(); r != nil { + atomic.AddInt64(&panicCount, 1) + } + }() + + // Each producer checks isClosing() first (mirrors tryRequest) + if pq.isClosing() { + atomic.AddInt64(&shutdownErrors, 1) + return + } + + // Signal that at least one producer passed the isClosing() check + gateOnce.Do(func() { close(passedGate) }) + + // Wait for shutdown to be signaled (the TOCTOU window) + <-shutdownFired + + // Producers now enter the select — with the fix, done is closed but + // queue is NOT closed, so this select is always safe (no panic) + msg := &ChannelMessage{} + select { + case pq.queue <- msg: + atomic.AddInt64(&successfulSends, 1) + case <-pq.done: + atomic.AddInt64(&shutdownErrors, 1) + } + }() + } + + // Wait for at least one producer to pass the isClosing() gate + select { + case <-passedGate: + case <-time.After(2 * time.Second): + t.Fatal("no producer passed the isClosing() check within timeout") + } + + // Signal shutdown (RemoveProvider step 2) — this is the TOCTOU race + pq.signalClosing() + close(shutdownFired) + + producerWg.Wait() + + if n := atomic.LoadInt64(&panicCount); n > 0 { + t.Errorf("detected %d panic(s) — queue must not be closed during concurrent shutdown", n) + } + + t.Logf("result: %d successful sends, %d shutdown errors, %d panics across %d producers", + atomic.LoadInt64(&successfulSends), + atomic.LoadInt64(&shutdownErrors), + atomic.LoadInt64(&panicCount), + numProducers) +} diff --git 
a/transports/bifrost-http/handlers/providers.go b/transports/bifrost-http/handlers/providers.go index b1be3d9ba9..353ed6005b 100644 --- a/transports/bifrost-http/handlers/providers.go +++ b/transports/bifrost-http/handlers/providers.go @@ -546,7 +546,6 @@ func (h *ProviderHandler) updateProvider(ctx *fasthttp.RequestCtx) { // Attempt model discovery err = h.attemptModelDiscovery(ctx, provider, payload.CustomProviderConfig) - if err != nil { logger.Warn("Model discovery failed for provider %s: %v", provider, err) } @@ -1282,7 +1281,6 @@ func (h *ProviderHandler) attemptModelDiscovery(ctx *fasthttp.RequestCtx, provid defer cancel() _, err := h.modelsManager.ReloadProvider(ctxWithTimeout, provider) - if err != nil { return err }