1 change: 1 addition & 0 deletions AGENTS.md
@@ -648,3 +648,4 @@ Systematically address unresolved PR review comments. Uses GraphQL to get unreso
- **Converter functions**: Pure — no side effects, no logging, no HTTP.
- **Pool names**: Descriptive string passed to `pool.New()` (e.g., `"channel-message"`, `"response-stream"`).
- **Context keys**: Use `BifrostContextKey` type. Custom plugins should define their own key types to avoid collisions.
- **Go filenames**: No underscores. The only permitted underscore is the `_test.go` suffix. Examples: `pluginpipeline.go`, `pluginpipeline_test.go` — never `plugin_pipeline.go` or `plugin_pipeline_race_test.go`. Concatenate words (lowercase, no separators) for multi-word filenames.
176 changes: 123 additions & 53 deletions core/bifrost.go
@@ -170,7 +170,12 @@ type PluginPipeline struct
preHookErrors []error
postHookErrors []error

// Streaming post-hook timing accumulation (for aggregated spans)
// streamingMu guards the streaming post-hook accumulators below. Per-chunk
// writes (accumulatePluginTiming) run in the provider goroutine while the
// end-of-stream finalizer (FinalizeStreamingPostHookSpans) and
// resetPluginPipeline can run in a different goroutine, so unsynchronised
// access triggers "concurrent map read and map write" panics.
streamingMu sync.Mutex
postHookTimings map[string]*pluginTimingAccumulator // keyed by plugin name
postHookPluginOrder []string // order in which post-hooks ran (for nested span creation)
chunkCount int
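
A minimal standalone sketch of the guarded-accumulator pattern this comment describes (hypothetical names, not the Bifrost types): every path that touches the timing map, the per-chunk writer, the end-of-stream reader, and the reset, takes the same mutex, which is what removes the "concurrent map read and map write" panic.

package main

import (
	"sync"
	"time"
)

type acc struct {
	total time.Duration
	n     int
}

type timings struct {
	mu     sync.Mutex
	byName map[string]*acc
}

// record is the per-chunk write path (provider goroutine).
func (t *timings) record(name string, d time.Duration) {
	t.mu.Lock()
	defer t.mu.Unlock()
	if t.byName == nil {
		t.byName = make(map[string]*acc)
	}
	a, ok := t.byName[name]
	if !ok {
		a = &acc{}
		t.byName[name] = a
	}
	a.total += d
	a.n++
}

// reset is the pool-release write path; it may run on a different goroutine.
func (t *timings) reset() {
	t.mu.Lock()
	defer t.mu.Unlock()
	clear(t.byName)
}

func main() {
	var t timings
	t.record("plugin-a", time.Millisecond)
	t.reset() // safe: both paths hold t.mu
}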
@@ -5212,48 +5217,62 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas
}
}
}
// Create plugin pipeline for streaming requests outside retry loop to prevent leaks
var postHookRunner schemas.PostHookRunner
var pipeline *PluginPipeline
if IsStreamRequestType(req.RequestType) {
pipeline = bifrost.getPluginPipeline()
postHookRunner = func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, result, err, len(*bifrost.llmPlugins.Load()))
if bifrostErr != nil {
return nil, bifrostErr
}
return resp, nil
}
// Store a finalizer callback to create aggregated post-hook spans at stream end.
// Wrapped in sync.Once so the normal end-of-stream invocation and a deferred
// safety-net invocation (e.g. from a provider goroutine's panic path) cannot
// double-release the pipeline.
var finalizerOnce sync.Once
postHookSpanFinalizer := func(ctx context.Context) {
finalizerOnce.Do(func() {
pipeline.FinalizeStreamingPostHookSpans(ctx)
bifrost.releasePluginPipeline(pipeline)
})
}
req.Context.SetValue(schemas.BifrostContextKeyPostHookSpanFinalizer, postHookSpanFinalizer)
}

// Execute request with retries
// Execute request with retries. For streaming, the plugin pipeline,
// postHookRunner, and finalizer are allocated per-attempt inside the
// request handler closure. If they were request-scoped, a retry
// triggered by CheckFirstStreamChunkForError could run against a
// pipeline the previous attempt's provider goroutine has already
// returned to the pool via its deferred finalizer.
if IsStreamRequestType(req.RequestType) {
stream, bifrostError = executeRequestWithRetries(req.Context, config, func() (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) {
return bifrost.handleProviderStreamRequest(provider, req, key, postHookRunner)
pipeline := bifrost.getPluginPipeline()
postHookRunner := func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, result, err, len(*bifrost.llmPlugins.Load()))
if bifrostErr != nil {
return nil, bifrostErr
}
return resp, nil
}
// sync.Once guards pipeline release so the normal end-of-stream
// invocation and a deferred safety-net invocation (panic path)
// cannot double-release the pipeline.
var finalizerOnce sync.Once
postHookSpanFinalizer := func(ctx context.Context) {
finalizerOnce.Do(func() {
pipeline.FinalizeStreamingPostHookSpans(ctx)
bifrost.releasePluginPipeline(pipeline)
})
}
req.Context.SetValue(schemas.BifrostContextKeyPostHookSpanFinalizer, postHookSpanFinalizer)

streamCh, streamErr := bifrost.handleProviderStreamRequest(provider, req, key, postHookRunner)
// If stream setup failed before any provider goroutine started,
// no deferred finalizer will run — release the pipeline directly
// so a retry doesn't inherit a leaked pool entry.
if streamErr != nil && streamCh == nil {
finalizerOnce.Do(func() {
bifrost.releasePluginPipeline(pipeline)
})
}
return streamCh, streamErr
}, req.RequestType, provider.GetProviderKey(), model, &req.BifrostRequest, bifrost.logger)
} else {
result, bifrostError = executeRequestWithRetries(req.Context, config, func() (*schemas.BifrostResponse, *schemas.BifrostError) {
return bifrost.handleProviderRequest(provider, config, req, key, keys)
}, req.RequestType, provider.GetProviderKey(), model, &req.BifrostRequest, bifrost.logger)
}

// Release pipeline immediately for non-streaming requests only
// For streaming, the pipeline is released in the postHookSpanFinalizer after streaming completes
// Exception: if streaming request has an error, release immediately since finalizer won't be called
if pipeline != nil && (!IsStreamRequestType(req.RequestType) || bifrostError != nil) {
bifrost.releasePluginPipeline(pipeline)
// For streaming with an error, route release through the ctx-registered
// finalizer (wrapped in sync.Once) so we don't double-Put into the pool
// or race the provider goroutine's deferred FinalizeStreamingPostHookSpans
// call. The finalizer is always the LAST attempt's finalizer — earlier
// attempts' finalizers have already fired via their provider goroutines'
// defers. For streaming without error, the finalizer is invoked by
// completeDeferredSpan / the provider goroutine's defer.
if IsStreamRequestType(req.RequestType) && bifrostError != nil {
if finalizer, ok := req.Context.Value(schemas.BifrostContextKeyPostHookSpanFinalizer).(func(context.Context)); ok && finalizer != nil {
finalizer(req.Context)
}
}

if bifrostError != nil {
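
For reference, a self-contained sketch of the sync.Once release discipline used above (generic pool and hypothetical names, not the real releasePluginPipeline): the normal end-of-stream call and the deferred safety-net call can both fire, but the pooled object is returned exactly once.

package main

import (
	"fmt"
	"sync"
)

func main() {
	pool := sync.Pool{New: func() any { return new(int) }}
	p := pool.Get().(*int)

	var once sync.Once
	release := func(reason string) {
		once.Do(func() {
			fmt.Println("released via", reason)
			pool.Put(p)
		})
	}

	defer release("safety-net defer") // panic / early-return path
	release("normal end of stream")   // happy path

	// Only "released via normal end of stream" is printed; the deferred call
	// is a no-op because once has already fired, so there is no double-Put.
}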
@@ -5881,7 +5900,9 @@ func (p *PluginPipeline) RunPostLLMHooks(ctx *schemas.BifrostContext, resp *sche
}
// Increment chunk count for streaming
if isStreaming {
p.streamingMu.Lock()
p.chunkCount++
p.streamingMu.Unlock()
}
// Final logic: if both are set, error takes precedence, unless error is nil
if bifrostErr != nil {
@@ -6002,19 +6023,38 @@ func (p *PluginPipeline) RunMCPPostHooks(ctx *schemas.BifrostContext, mcpResp *s

// resetPluginPipeline resets a PluginPipeline instance for reuse
func (p *PluginPipeline) resetPluginPipeline() {
// Drop cross-request references while the object sits in the pool.
// getPluginPipeline rebinds all four on acquisition, so nil'ing here
// only affects GC hygiene — important when plugins are hot-swapped.
p.llmPlugins = nil
p.mcpPlugins = nil
p.executedPreHooks = 0
p.preHookErrors = p.preHookErrors[:0]
p.postHookErrors = p.postHookErrors[:0]
// Reset streaming timing accumulation
// Reset streaming timing accumulation under lock — the provider goroutine's
// deferred finalizer may still be iterating these fields when the pipeline
// is returned to the pool. logger/tracer are nilled here too so the write
// is synchronized with the finalizer's read under the same mutex.
p.streamingMu.Lock()
p.logger = nil
p.tracer = nil
p.chunkCount = 0
if p.postHookTimings != nil {
// clear() drops *pluginTimingAccumulator values (freeing them for GC)
// while retaining the map's backing hash table for reuse.
clear(p.postHookTimings)
}
// clear() zeros elements in [0, len) — scrub before [:0] so the backing
// array doesn't retain live string references once the slice is truncated.
clear(p.postHookPluginOrder)
p.postHookPluginOrder = p.postHookPluginOrder[:0]
p.streamingMu.Unlock()
}
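
A tiny standalone demo (Go 1.21+ clear builtin, not part of this diff) of the reset idiom above: clearing a map drops its entries while keeping the allocated buckets for reuse, and clearing a slice zeroes the elements so truncating to [:0] does not leave the backing array pinning the old strings.

package main

import "fmt"

func main() {
	m := map[string]*int{"a": new(int), "b": new(int)}
	clear(m)            // removes every entry; the map itself stays allocated for reuse
	fmt.Println(len(m)) // 0

	s := []string{"plugin-a", "plugin-b"}
	clear(s)  // zeros elements in [0, len): s is now ["", ""]
	s = s[:0] // truncate; the backing array no longer references the old strings
	fmt.Println(len(s), cap(s)) // 0 2
}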

// accumulatePluginTiming accumulates timing for a plugin during streaming
func (p *PluginPipeline) accumulatePluginTiming(pluginName string, duration time.Duration, hasError bool) {
p.streamingMu.Lock()
defer p.streamingMu.Unlock()
if p.postHookTimings == nil {
p.postHookTimings = make(map[string]*pluginTimingAccumulator)
}
@@ -6036,7 +6076,40 @@ func (p *PluginPipeline) accumulatePluginTiming(pluginName string, duration time
// This should be called once at the end of streaming to create one span per plugin with average timing.
// Spans are nested to mirror the pre-hook hierarchy (each post-hook is a child of the previous one).
func (p *PluginPipeline) FinalizeStreamingPostHookSpans(ctx context.Context) {
if p.postHookTimings == nil || len(p.postHookPluginOrder) == 0 {
// Snapshot the accumulators under lock so per-chunk writers in the
// provider goroutine can't race with the finalizer. Tracer calls below
// run unlocked — we don't want to stall chunk writers on span I/O.
type snapshotEntry struct {
pluginName string
totalDuration time.Duration
invocations int
errors int
}
p.streamingMu.Lock()
// Capture tracer under the same lock that guards resetPluginPipeline's
// writes so the read/write pair on p.tracer is synchronized and the
// unlocked tracer calls below use a stable local.
tracer := p.tracer
if tracer == nil || p.postHookTimings == nil || len(p.postHookPluginOrder) == 0 {
p.streamingMu.Unlock()
return
}
snapshot := make([]snapshotEntry, 0, len(p.postHookPluginOrder))
for _, pluginName := range p.postHookPluginOrder {
timing, ok := p.postHookTimings[pluginName]
if !ok || timing.invocations == 0 {
continue
}
snapshot = append(snapshot, snapshotEntry{
pluginName: pluginName,
totalDuration: timing.totalDuration,
invocations: timing.invocations,
errors: timing.errors,
})
}
p.streamingMu.Unlock()

if len(snapshot) == 0 {
return
}

Expand All @@ -6045,50 +6118,47 @@ func (p *PluginPipeline) FinalizeStreamingPostHookSpans(ctx context.Context) {
handle schemas.SpanHandle
hasErrors bool
}
spans := make([]spanInfo, 0, len(p.postHookPluginOrder))
spans := make([]spanInfo, 0, len(snapshot))
currentCtx := ctx

// Start spans in execution order (nested: each is a child of the previous)
for _, pluginName := range p.postHookPluginOrder {
timing, ok := p.postHookTimings[pluginName]
if !ok || timing.invocations == 0 {
continue
}

for _, entry := range snapshot {
// Create span as child of the previous span (nested hierarchy)
newCtx, handle := p.tracer.StartSpan(currentCtx, fmt.Sprintf("plugin.%s.posthook", sanitizeSpanName(pluginName)), schemas.SpanKindPlugin)
newCtx, handle := tracer.StartSpan(currentCtx, fmt.Sprintf("plugin.%s.posthook", sanitizeSpanName(entry.pluginName)), schemas.SpanKindPlugin)
if handle == nil {
continue
}

// Calculate average duration in milliseconds
avgMs := float64(timing.totalDuration.Milliseconds()) / float64(timing.invocations)
avgMs := float64(entry.totalDuration.Milliseconds()) / float64(entry.invocations)

// Set aggregated attributes
p.tracer.SetAttribute(handle, schemas.AttrPluginInvocations, timing.invocations)
p.tracer.SetAttribute(handle, schemas.AttrPluginAvgDurationMs, avgMs)
p.tracer.SetAttribute(handle, schemas.AttrPluginTotalDurationMs, timing.totalDuration.Milliseconds())
tracer.SetAttribute(handle, schemas.AttrPluginInvocations, entry.invocations)
tracer.SetAttribute(handle, schemas.AttrPluginAvgDurationMs, avgMs)
tracer.SetAttribute(handle, schemas.AttrPluginTotalDurationMs, entry.totalDuration.Milliseconds())

if timing.errors > 0 {
p.tracer.SetAttribute(handle, schemas.AttrPluginErrorCount, timing.errors)
if entry.errors > 0 {
tracer.SetAttribute(handle, schemas.AttrPluginErrorCount, entry.errors)
}

spans = append(spans, spanInfo{handle: handle, hasErrors: timing.errors > 0})
spans = append(spans, spanInfo{handle: handle, hasErrors: entry.errors > 0})
currentCtx = newCtx
}

// End spans in reverse order (innermost first, like unwinding a call stack)
for i := len(spans) - 1; i >= 0; i-- {
if spans[i].hasErrors {
p.tracer.EndSpan(spans[i].handle, schemas.SpanStatusError, "some invocations failed")
tracer.EndSpan(spans[i].handle, schemas.SpanStatusError, "some invocations failed")
} else {
p.tracer.EndSpan(spans[i].handle, schemas.SpanStatusOk, "")
tracer.EndSpan(spans[i].handle, schemas.SpanStatusOk, "")
}
}
}

// GetChunkCount returns the number of chunks processed during streaming
func (p *PluginPipeline) GetChunkCount() int {
p.streamingMu.Lock()
defer p.streamingMu.Unlock()
return p.chunkCount
}

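As a general illustration of the snapshot-under-lock shape used by FinalizeStreamingPostHookSpans (hypothetical types, not the Bifrost tracer API): copy the shared accumulators while holding the mutex, unlock, then do the slow span/IO work on the private copy so per-chunk writers never block on tracing.

package main

import (
	"fmt"
	"sync"
	"time"
)

type stats struct {
	mu     sync.Mutex
	totals map[string]time.Duration
}

func (s *stats) finalize() {
	// 1. Snapshot under lock so concurrent writers cannot race the read.
	s.mu.Lock()
	snapshot := make(map[string]time.Duration, len(s.totals))
	for k, v := range s.totals {
		snapshot[k] = v
	}
	s.mu.Unlock()

	// 2. Slow work on the copy, with the lock released.
	for name, total := range snapshot {
		fmt.Printf("plugin %s: %v total\n", name, total) // stands in for span emission
	}
}

func main() {
	s := &stats{totals: map[string]time.Duration{"plugin-a": time.Millisecond}}
	s.finalize()
}
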
77 changes: 77 additions & 0 deletions core/pluginpipelinerace_test.go
@@ -0,0 +1,77 @@
package bifrost

import (
"context"
"fmt"
"sync"
"testing"
"time"

schemas "github.com/maximhq/bifrost/core/schemas"
)

// TestPluginPipelineStreamingRace reproduces the production panic:
//
// fatal error: concurrent map read and map write
// (*PluginPipeline).FinalizeStreamingPostHookSpans
//
// It hammers accumulatePluginTiming (per-chunk writer) concurrently with
// FinalizeStreamingPostHookSpans (end-of-stream reader) and resetPluginPipeline
// (pool-release writer). Before the streamingMu fix these three paths had no
// synchronisation and the -race detector / runtime map check would trip
// immediately. Run with: go test -race -run PluginPipelineStreamingRace
func TestPluginPipelineStreamingRace(t *testing.T) {
p := &PluginPipeline{
logger: NewDefaultLogger(schemas.LogLevelError),
tracer: &schemas.NoOpTracer{},
}

const writers = 8
const iterations = 2000

var wg sync.WaitGroup

// Per-chunk accumulator writers — simulate multiple plugins accumulating
// timing for every streamed chunk.
for w := 0; w < writers; w++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
pluginName := fmt.Sprintf("plugin-%d", id%3) // a few distinct plugin keys
for i := 0; i < iterations; i++ {
p.accumulatePluginTiming(pluginName, time.Microsecond, i%17 == 0)
}
}(w)
}

// End-of-stream finalizer racing with writers.
wg.Add(1)
go func() {
defer wg.Done()
ctx := context.Background()
for i := 0; i < iterations/10; i++ {
p.FinalizeStreamingPostHookSpans(ctx)
}
}()

// resetPluginPipeline racing with writers — simulates the pool returning
// the pipeline to another request mid-flight.
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < iterations/10; i++ {
p.resetPluginPipeline()
}
}()

// Concurrent GetChunkCount readers.
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < iterations; i++ {
_ = p.GetChunkCount()
}
}()

wg.Wait()
}