diff --git a/AGENTS.md b/AGENTS.md
index 03fdd812ee..0a4ff1632b 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -648,3 +648,4 @@ Systematically address unresolved PR review comments. Uses GraphQL to get unreso
 - **Converter functions**: Pure — no side effects, no logging, no HTTP.
 - **Pool names**: Descriptive string passed to `pool.New()` (e.g., `"channel-message"`, `"response-stream"`).
 - **Context keys**: Use `BifrostContextKey` type. Custom plugins should define their own key types to avoid collisions.
+- **Go filenames**: No underscores. The only permitted underscore is the `_test.go` suffix. Examples: `pluginpipeline.go`, `pluginpipeline_test.go` — never `plugin_pipeline.go` or `plugin_pipeline_race_test.go`. Concatenate words (lowercase, no separators) for multi-word filenames.
diff --git a/core/bifrost.go b/core/bifrost.go
index 70454a96e6..3eac3c9448 100644
--- a/core/bifrost.go
+++ b/core/bifrost.go
@@ -170,7 +170,12 @@ type PluginPipeline struct {
 	preHookErrors  []error
 	postHookErrors []error
 
-	// Streaming post-hook timing accumulation (for aggregated spans)
+	// streamingMu guards the streaming post-hook accumulators below. Per-chunk
+	// writes (accumulatePluginTiming) run in the provider goroutine while the
+	// end-of-stream finalizer (FinalizeStreamingPostHookSpans) and
+	// resetPluginPipeline can run in a different goroutine, so unsynchronised
+	// access triggers "concurrent map read and map write" panics.
+	streamingMu         sync.Mutex
 	postHookTimings     map[string]*pluginTimingAccumulator // keyed by plugin name
 	postHookPluginOrder []string                            // order in which post-hooks ran (for nested span creation)
 	chunkCount          int
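The struct comment describes the textbook Go map race: one goroutine iterating `postHookTimings` while another inserts into it. The standalone sketch below is not part of the patch; `pipeline` and `accumulator` are hypothetical stand-ins for `PluginPipeline` and `pluginTimingAccumulator`. It shows the minimal shape of the hazard that `streamingMu` closes: with the `mu` lines deleted, `go run -race` flags it immediately; with them in place it runs clean.

```go
package main

import "sync"

type accumulator struct{ invocations int }

type pipeline struct {
	mu      sync.Mutex // plays the role of streamingMu
	timings map[string]*accumulator
}

// write is the per-chunk writer (the provider goroutine in Bifrost).
func (p *pipeline) write(name string) {
	p.mu.Lock()
	defer p.mu.Unlock()
	if p.timings == nil {
		p.timings = make(map[string]*accumulator)
	}
	if p.timings[name] == nil {
		p.timings[name] = &accumulator{}
	}
	p.timings[name].invocations++
}

// read is the end-of-stream reader (the finalizer in Bifrost).
func (p *pipeline) read() (n int) {
	p.mu.Lock()
	defer p.mu.Unlock()
	for _, a := range p.timings {
		n += a.invocations
	}
	return n
}

func main() {
	p := &pipeline{}
	done := make(chan struct{})
	go func() {
		for i := 0; i < 100_000; i++ {
			p.write("plugin-a")
		}
		close(done)
	}()
	for i := 0; i < 100_000; i++ {
		_ = p.read()
	}
	<-done
}
```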
@@ -5212,36 +5217,44 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas
 			}
 		}
 	}
-	// Create plugin pipeline for streaming requests outside retry loop to prevent leaks
-	var postHookRunner schemas.PostHookRunner
-	var pipeline *PluginPipeline
-	if IsStreamRequestType(req.RequestType) {
-		pipeline = bifrost.getPluginPipeline()
-		postHookRunner = func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
-			resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, result, err, len(*bifrost.llmPlugins.Load()))
-			if bifrostErr != nil {
-				return nil, bifrostErr
-			}
-			return resp, nil
-		}
-		// Store a finalizer callback to create aggregated post-hook spans at stream end.
-		// Wrapped in sync.Once so the normal end-of-stream invocation and a deferred
-		// safety-net invocation (e.g. from a provider goroutine's panic path) cannot
-		// double-release the pipeline.
-		var finalizerOnce sync.Once
-		postHookSpanFinalizer := func(ctx context.Context) {
-			finalizerOnce.Do(func() {
-				pipeline.FinalizeStreamingPostHookSpans(ctx)
-				bifrost.releasePluginPipeline(pipeline)
-			})
-		}
-		req.Context.SetValue(schemas.BifrostContextKeyPostHookSpanFinalizer, postHookSpanFinalizer)
-	}
-
-	// Execute request with retries
+	// Execute request with retries. For streaming, the plugin pipeline,
+	// postHookRunner, and finalizer are allocated per-attempt inside the
+	// request handler closure. If they were request-scoped, a retry
+	// triggered by CheckFirstStreamChunkForError could run against a
+	// pipeline the previous attempt's provider goroutine has already
+	// returned to the pool via its deferred finalizer.
 	if IsStreamRequestType(req.RequestType) {
 		stream, bifrostError = executeRequestWithRetries(req.Context, config, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
-			return bifrost.handleProviderStreamRequest(provider, req, key, postHookRunner)
+			pipeline := bifrost.getPluginPipeline()
+			postHookRunner := func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
+				resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, result, err, len(*bifrost.llmPlugins.Load()))
+				if bifrostErr != nil {
+					return nil, bifrostErr
+				}
+				return resp, nil
+			}
+			// sync.Once guards pipeline release so the normal end-of-stream
+			// invocation and a deferred safety-net invocation (panic path)
+			// cannot double-release the pipeline.
+			var finalizerOnce sync.Once
+			postHookSpanFinalizer := func(ctx context.Context) {
+				finalizerOnce.Do(func() {
+					pipeline.FinalizeStreamingPostHookSpans(ctx)
+					bifrost.releasePluginPipeline(pipeline)
+				})
+			}
+			req.Context.SetValue(schemas.BifrostContextKeyPostHookSpanFinalizer, postHookSpanFinalizer)
+
+			streamCh, streamErr := bifrost.handleProviderStreamRequest(provider, req, key, postHookRunner)
+			// If stream setup failed before any provider goroutine started,
+			// no deferred finalizer will run — release the pipeline directly
+			// so a retry doesn't inherit a leaked pool entry.
+			if streamErr != nil && streamCh == nil {
+				finalizerOnce.Do(func() {
+					bifrost.releasePluginPipeline(pipeline)
+				})
+			}
+			return streamCh, streamErr
 		}, req.RequestType, provider.GetProviderKey(), model, &req.BifrostRequest, bifrost.logger)
 	} else {
 		result, bifrostError = executeRequestWithRetries(req.Context, config, func() (*schemas.BifrostResponse, *schemas.BifrostError) {
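Three call sites can end up releasing the same pipeline: the normal end-of-stream finalizer, the provider goroutine's deferred safety net, and the setup-failure fallback above. A minimal sketch of the `sync.Once` guard in isolation (`releaseGuard` is illustrative, not a Bifrost type):

```go
package main

import (
	"fmt"
	"sync"
)

// releaseGuard wraps a release function in sync.Once: however many paths
// call finalize (normal end-of-stream, deferred panic safety net,
// setup-failure fallback), the wrapped release runs exactly once.
type releaseGuard struct {
	once    sync.Once
	release func()
}

func (g *releaseGuard) finalize() {
	g.once.Do(g.release)
}

func main() {
	released := 0
	g := &releaseGuard{release: func() { released++ }}

	g.finalize() // normal end-of-stream invocation
	g.finalize() // deferred safety net on the panic path
	g.finalize() // setup-failure fallback

	fmt.Println(released) // 1: the pool never sees a double-Put
}
```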
@@ -5249,11 +5262,17 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas
 		}, req.RequestType, provider.GetProviderKey(), model, &req.BifrostRequest, bifrost.logger)
 	}
 
-	// Release pipeline immediately for non-streaming requests only
-	// For streaming, the pipeline is released in the postHookSpanFinalizer after streaming completes
-	// Exception: if streaming request has an error, release immediately since finalizer won't be called
-	if pipeline != nil && (!IsStreamRequestType(req.RequestType) || bifrostError != nil) {
-		bifrost.releasePluginPipeline(pipeline)
+	// For streaming with an error, route release through the ctx-registered
+	// finalizer (wrapped in sync.Once) so we don't double-Put into the pool
+	// or race the provider goroutine's deferred FinalizeStreamingPostHookSpans
+	// call. The finalizer is always the LAST attempt's finalizer — earlier
+	// attempts' finalizers have already fired via their provider goroutines'
+	// defers. For streaming without error, the finalizer is invoked by
+	// completeDeferredSpan / the provider goroutine's defer.
+	if IsStreamRequestType(req.RequestType) && bifrostError != nil {
+		if finalizer, ok := req.Context.Value(schemas.BifrostContextKeyPostHookSpanFinalizer).(func(context.Context)); ok && finalizer != nil {
+			finalizer(req.Context)
+		}
 	}
 
 	if bifrostError != nil {
@@ -5881,7 +5900,9 @@ func (p *PluginPipeline) RunPostLLMHooks(ctx *schemas.BifrostContext, resp *sche
 	}
 	// Increment chunk count for streaming
 	if isStreaming {
+		p.streamingMu.Lock()
 		p.chunkCount++
+		p.streamingMu.Unlock()
 	}
 	// Final logic: if both are set, error takes precedence, unless error is nil
 	if bifrostErr != nil {
@@ -6002,19 +6023,38 @@ func (p *PluginPipeline) RunMCPPostHooks(ctx *schemas.BifrostContext, mcpResp *s
 
 // resetPluginPipeline resets a PluginPipeline instance for reuse
 func (p *PluginPipeline) resetPluginPipeline() {
+	// Drop cross-request references while the object sits in the pool.
+	// getPluginPipeline rebinds all four on acquisition, so nil'ing here
+	// only affects GC hygiene — important when plugins are hot-swapped.
+	p.llmPlugins = nil
+	p.mcpPlugins = nil
 	p.executedPreHooks = 0
 	p.preHookErrors = p.preHookErrors[:0]
 	p.postHookErrors = p.postHookErrors[:0]
-	// Reset streaming timing accumulation
+	// Reset streaming timing accumulation under lock — the provider goroutine's
+	// deferred finalizer may still be iterating these fields when the pipeline
+	// is returned to the pool. logger/tracer are nilled here too so the write
+	// is synchronized with the finalizer's read under the same mutex.
+	p.streamingMu.Lock()
+	p.logger = nil
+	p.tracer = nil
 	p.chunkCount = 0
 	if p.postHookTimings != nil {
+		// clear() drops *pluginTimingAccumulator values (freeing them for GC)
+		// while retaining the map's backing hash table for reuse.
 		clear(p.postHookTimings)
 	}
+	// clear() zeros elements in [0, len) — scrub before [:0] so the backing
+	// array doesn't retain live string references once the slice is truncated.
+	clear(p.postHookPluginOrder)
 	p.postHookPluginOrder = p.postHookPluginOrder[:0]
+	p.streamingMu.Unlock()
 }
 
 // accumulatePluginTiming accumulates timing for a plugin during streaming
 func (p *PluginPipeline) accumulatePluginTiming(pluginName string, duration time.Duration, hasError bool) {
+	p.streamingMu.Lock()
+	defer p.streamingMu.Unlock()
 	if p.postHookTimings == nil {
 		p.postHookTimings = make(map[string]*pluginTimingAccumulator)
 	}
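The reset path leans on the `clear` builtin (Go 1.21+) for both the map and the slice, with a different payoff in each case, as the inline comments note. A self-contained demonstration of both idioms (the plugin names are invented):

```go
package main

import "fmt"

func main() {
	// Slice reset: clear() zeroes every element in [0, len(order)), so the
	// backing array stops pinning the old strings; [:0] then truncates the
	// length while keeping the capacity for the next request.
	order := []string{"governance", "telemetry", "cache"} // hypothetical plugin names
	clear(order)
	order = order[:0]
	fmt.Println(len(order), cap(order)) // 0 3

	// Map reset: clear() deletes all entries but keeps the allocated
	// buckets, so refilling the map later avoids rehashing from scratch.
	timings := map[string]int{"governance": 3, "telemetry": 5}
	clear(timings)
	fmt.Println(len(timings)) // 0
}
```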
@@ -6036,7 +6076,40 @@
 // This should be called once at the end of streaming to create one span per plugin with average timing.
 // Spans are nested to mirror the pre-hook hierarchy (each post-hook is a child of the previous one).
 func (p *PluginPipeline) FinalizeStreamingPostHookSpans(ctx context.Context) {
-	if p.postHookTimings == nil || len(p.postHookPluginOrder) == 0 {
+	// Snapshot the accumulators under lock so per-chunk writers in the
+	// provider goroutine can't race with the finalizer. Tracer calls below
+	// run unlocked — we don't want to stall chunk writers on span I/O.
+	type snapshotEntry struct {
+		pluginName    string
+		totalDuration time.Duration
+		invocations   int
+		errors        int
+	}
+	p.streamingMu.Lock()
+	// Capture tracer under the same lock that guards resetPluginPipeline's
+	// writes so the read/write pair on p.tracer is synchronized and the
+	// unlocked tracer calls below use a stable local.
+	tracer := p.tracer
+	if tracer == nil || p.postHookTimings == nil || len(p.postHookPluginOrder) == 0 {
+		p.streamingMu.Unlock()
+		return
+	}
+	snapshot := make([]snapshotEntry, 0, len(p.postHookPluginOrder))
+	for _, pluginName := range p.postHookPluginOrder {
+		timing, ok := p.postHookTimings[pluginName]
+		if !ok || timing.invocations == 0 {
+			continue
+		}
+		snapshot = append(snapshot, snapshotEntry{
+			pluginName:    pluginName,
+			totalDuration: timing.totalDuration,
+			invocations:   timing.invocations,
+			errors:        timing.errors,
+		})
+	}
+	p.streamingMu.Unlock()
+
+	if len(snapshot) == 0 {
 		return
 	}
 
@@ -6045,50 +6118,47 @@ func (p *PluginPipeline) FinalizeStreamingPostHookSpans(ctx context.Context) {
 		handle    schemas.SpanHandle
 		hasErrors bool
 	}
-	spans := make([]spanInfo, 0, len(p.postHookPluginOrder))
+	spans := make([]spanInfo, 0, len(snapshot))
 	currentCtx := ctx
 
 	// Start spans in execution order (nested: each is a child of the previous)
-	for _, pluginName := range p.postHookPluginOrder {
-		timing, ok := p.postHookTimings[pluginName]
-		if !ok || timing.invocations == 0 {
-			continue
-		}
-
+	for _, entry := range snapshot {
 		// Create span as child of the previous span (nested hierarchy)
-		newCtx, handle := p.tracer.StartSpan(currentCtx, fmt.Sprintf("plugin.%s.posthook", sanitizeSpanName(pluginName)), schemas.SpanKindPlugin)
+		newCtx, handle := tracer.StartSpan(currentCtx, fmt.Sprintf("plugin.%s.posthook", sanitizeSpanName(entry.pluginName)), schemas.SpanKindPlugin)
 		if handle == nil {
 			continue
 		}
 
 		// Calculate average duration in milliseconds
-		avgMs := float64(timing.totalDuration.Milliseconds()) / float64(timing.invocations)
+		avgMs := float64(entry.totalDuration.Milliseconds()) / float64(entry.invocations)
 
 		// Set aggregated attributes
-		p.tracer.SetAttribute(handle, schemas.AttrPluginInvocations, timing.invocations)
-		p.tracer.SetAttribute(handle, schemas.AttrPluginAvgDurationMs, avgMs)
-		p.tracer.SetAttribute(handle, schemas.AttrPluginTotalDurationMs, timing.totalDuration.Milliseconds())
+		tracer.SetAttribute(handle, schemas.AttrPluginInvocations, entry.invocations)
+		tracer.SetAttribute(handle, schemas.AttrPluginAvgDurationMs, avgMs)
+		tracer.SetAttribute(handle, schemas.AttrPluginTotalDurationMs, entry.totalDuration.Milliseconds())
 
-		if timing.errors > 0 {
-			p.tracer.SetAttribute(handle, schemas.AttrPluginErrorCount, timing.errors)
+		if entry.errors > 0 {
+			tracer.SetAttribute(handle, schemas.AttrPluginErrorCount, entry.errors)
 		}
 
-		spans = append(spans, spanInfo{handle: handle, hasErrors: timing.errors > 0})
+		spans = append(spans, spanInfo{handle: handle, hasErrors: entry.errors > 0})
 		currentCtx = newCtx
 	}
 
 	// End spans in reverse order (innermost first, like unwinding a call stack)
 	for i := len(spans) - 1; i >= 0; i-- {
 		if spans[i].hasErrors {
-			p.tracer.EndSpan(spans[i].handle, schemas.SpanStatusError, "some invocations failed")
+			tracer.EndSpan(spans[i].handle, schemas.SpanStatusError, "some invocations failed")
 		} else {
-			p.tracer.EndSpan(spans[i].handle, schemas.SpanStatusOk, "")
+			tracer.EndSpan(spans[i].handle, schemas.SpanStatusOk, "")
 		}
 	}
 }
 
 // GetChunkCount returns the number of chunks processed during streaming
 func (p *PluginPipeline) GetChunkCount() int {
+	p.streamingMu.Lock()
+	defer p.streamingMu.Unlock()
 	return p.chunkCount
 }
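FinalizeStreamingPostHookSpans is an instance of the snapshot-under-lock pattern: copy shared state while holding the mutex, release it, then do the slow work on the private copy. A minimal sketch with hypothetical names (`stats` and `flush` are not Bifrost types):

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

type stats struct {
	mu     sync.Mutex
	totals map[string]time.Duration
}

// flush copies shared state under the lock, then performs the slow,
// potentially blocking work (span creation in Bifrost; printing here)
// on the private copy, so per-chunk writers are never stalled behind it.
func (s *stats) flush() {
	s.mu.Lock()
	snapshot := make(map[string]time.Duration, len(s.totals))
	for name, total := range s.totals {
		snapshot[name] = total
	}
	s.mu.Unlock()

	for name, total := range snapshot { // unlocked: operates on the copy
		fmt.Printf("%s: %v\n", name, total)
	}
}

func main() {
	s := &stats{totals: map[string]time.Duration{"plugin-a": time.Millisecond}}
	s.flush()
}
```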
"sync" + "testing" + "time" + + schemas "github.com/maximhq/bifrost/core/schemas" +) + +// TestPluginPipelineStreamingRace reproduces the production panic: +// +// fatal error: concurrent map read and map write +// (*PluginPipeline).FinalizeStreamingPostHookSpans +// +// It hammers accumulatePluginTiming (per-chunk writer) concurrently with +// FinalizeStreamingPostHookSpans (end-of-stream reader) and resetPluginPipeline +// (pool-release writer). Before the streamingMu fix these three paths had no +// synchronisation and the -race detector / runtime map check would trip +// immediately. Run with: go test -race -run PluginPipelineStreamingRace +func TestPluginPipelineStreamingRace(t *testing.T) { + p := &PluginPipeline{ + logger: NewDefaultLogger(schemas.LogLevelError), + tracer: &schemas.NoOpTracer{}, + } + + const writers = 8 + const iterations = 2000 + + var wg sync.WaitGroup + + // Per-chunk accumulator writers — simulate multiple plugins accumulating + // timing for every streamed chunk. + for w := 0; w < writers; w++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + pluginName := fmt.Sprintf("plugin-%d", id%3) // a few distinct plugin keys + for i := 0; i < iterations; i++ { + p.accumulatePluginTiming(pluginName, time.Microsecond, i%17 == 0) + } + }(w) + } + + // End-of-stream finalizer racing with writers. + wg.Add(1) + go func() { + defer wg.Done() + ctx := context.Background() + for i := 0; i < iterations/10; i++ { + p.FinalizeStreamingPostHookSpans(ctx) + } + }() + + // resetPluginPipeline racing with writers — simulates the pool returning + // the pipeline to another request mid-flight. + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < iterations/10; i++ { + p.resetPluginPipeline() + } + }() + + // Concurrent GetChunkCount readers. + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < iterations; i++ { + _ = p.GetChunkCount() + } + }() + + wg.Wait() +}